From 5bf70b691f3a85f0221d744f9ad27c6bf1cdce6c Mon Sep 17 00:00:00 2001 From: Corentin Le Molgat Date: Fri, 16 Dec 2022 17:06:11 +0100 Subject: [PATCH] math_opt: export from google3 * CMake has not been updated yet * bazel was compiling at least last week bazel: disable math opt facility_location.py missing some dependencies... --- cmake/python.cmake | 38 +- ortools/math_opt/CMakeLists.txt | 2 + ortools/math_opt/README.md | 18 + ortools/math_opt/callback.proto | 12 +- .../constraints/indicator/BUILD.bazel | 4 +- .../indicator/indicator_constraint.cc | 12 +- .../indicator/indicator_constraint.h | 93 +- .../math_opt/constraints/indicator/storage.h | 5 + .../constraints/quadratic/BUILD.bazel | 3 + .../quadratic/quadratic_constraint.cc | 2 +- .../quadratic/quadratic_constraint.h | 106 +- .../math_opt/constraints/quadratic/storage.h | 6 + .../constraints/second_order_cone/BUILD.bazel | 1 + .../second_order_cone_constraint.cc | 8 +- .../second_order_cone_constraint.h | 26 +- .../constraints/second_order_cone/storage.h | 2 + ortools/math_opt/constraints/sos/BUILD.bazel | 3 + .../constraints/sos/sos1_constraint.cc | 4 +- .../constraints/sos/sos1_constraint.h | 32 +- .../constraints/sos/sos2_constraint.cc | 4 +- .../constraints/sos/sos2_constraint.h | 31 +- ortools/math_opt/constraints/sos/storage.h | 3 + .../math_opt/constraints/util/model_util.h | 2 +- ortools/math_opt/core/invalid_indicators.h | 1 + ortools/math_opt/core/inverted_bounds.h | 1 + ortools/math_opt/core/math_opt_proto_utils.cc | 1 + ortools/math_opt/core/python/CMakeLists.txt | 12 +- ortools/math_opt/core/python/solver_test.py | 4 +- ortools/math_opt/core/solver.cc | 4 +- ortools/math_opt/core/solver_debug.h | 3 +- ortools/math_opt/cpp/BUILD.bazel | 9 +- ortools/math_opt/cpp/basis_status.h | 2 - ortools/math_opt/cpp/callback.cc | 6 +- ortools/math_opt/cpp/callback.h | 20 +- .../compute_infeasible_subsystem_result.cc | 12 +- .../cpp/compute_infeasible_subsystem_result.h | 12 +- ortools/math_opt/cpp/key_types.h | 21 +- ortools/math_opt/cpp/linear_constraint.h | 96 +- ortools/math_opt/cpp/map_filter.h | 4 +- ortools/math_opt/cpp/matchers.cc | 3 +- ortools/math_opt/cpp/message_callback.cc | 11 +- ortools/math_opt/cpp/model.cc | 10 +- ortools/math_opt/cpp/model.h | 41 +- .../math_opt/cpp/model_solve_parameters.cc | 4 +- ortools/math_opt/cpp/model_solve_parameters.h | 4 +- ortools/math_opt/cpp/objective.cc | 8 +- ortools/math_opt/cpp/objective.h | 63 +- ortools/math_opt/cpp/parameters.cc | 5 +- ortools/math_opt/cpp/parameters.h | 3 +- ortools/math_opt/cpp/solution.cc | 17 +- ortools/math_opt/cpp/solution.h | 15 +- ortools/math_opt/cpp/solve_arguments.cc | 2 +- ortools/math_opt/cpp/solve_arguments.h | 2 +- ortools/math_opt/cpp/solve_impl.cc | 11 +- ortools/math_opt/cpp/solve_impl.h | 4 +- ortools/math_opt/cpp/solve_result.cc | 2 +- ortools/math_opt/cpp/solve_result.h | 2 +- ortools/math_opt/cpp/sparse_containers.cc | 26 +- ortools/math_opt/cpp/sparse_containers.h | 14 +- .../math_opt/cpp/variable_and_expressions.cc | 9 +- .../math_opt/cpp/variable_and_expressions.h | 213 +- ortools/math_opt/elemental/BUILD.bazel | 553 +++ ortools/math_opt/elemental/CMakeLists.txt | 30 + ortools/math_opt/elemental/README.md | 3 + ortools/math_opt/elemental/arrays.h | 74 + ortools/math_opt/elemental/arrays_test.cc | 179 + ortools/math_opt/elemental/attr_diff.h | 58 + ortools/math_opt/elemental/attr_diff_test.cc | 168 + ortools/math_opt/elemental/attr_key.h | 360 ++ ortools/math_opt/elemental/attr_key_test.cc | 335 ++ ortools/math_opt/elemental/attr_storage.h | 429 +++ .../math_opt/elemental/attr_storage_test.cc | 574 +++ ortools/math_opt/elemental/attributes.h | 349 ++ ortools/math_opt/elemental/attributes_test.cc | 58 + .../math_opt/elemental/codegen/BUILD.bazel | 81 + ortools/math_opt/elemental/codegen/codegen.cc | 52 + ortools/math_opt/elemental/codegen/gen.cc | 158 + ortools/math_opt/elemental/codegen/gen.h | 143 + ortools/math_opt/elemental/codegen/gen_c.cc | 245 ++ ortools/math_opt/elemental/codegen/gen_c.h | 32 + .../math_opt/elemental/codegen/gen_c_test.cc | 100 + .../math_opt/elemental/codegen/gen_python.cc | 161 + .../math_opt/elemental/codegen/gen_python.h | 30 + .../elemental/codegen/gen_python_test.cc | 60 + .../math_opt/elemental/codegen/gen_test.cc | 96 + ortools/math_opt/elemental/codegen/testing.h | 40 + ortools/math_opt/elemental/derived_data.h | 218 ++ .../math_opt/elemental/derived_data_test.cc | 172 + ortools/math_opt/elemental/diff.cc | 30 + ortools/math_opt/elemental/diff.h | 157 + ortools/math_opt/elemental/diff_test.cc | 289 ++ ortools/math_opt/elemental/element_diff.h | 64 + .../math_opt/elemental/element_diff_test.cc | 59 + .../math_opt/elemental/element_ref_tracker.h | 81 + .../elemental/element_ref_tracker_test.cc | 54 + ortools/math_opt/elemental/element_storage.cc | 56 + ortools/math_opt/elemental/element_storage.h | 192 + .../elemental/element_storage_test.cc | 170 + ortools/math_opt/elemental/elemental.cc | 169 + ortools/math_opt/elemental/elemental.h | 545 +++ .../elemental/elemental_differencer.cc | 253 ++ .../elemental/elemental_differencer.h | 233 ++ .../elemental/elemental_differencer_test.cc | 305 ++ .../elemental/elemental_export_model.cc | 1078 ++++++ .../elemental/elemental_export_model_test.cc | 506 +++ .../elemental_export_model_update_test.cc | 1251 +++++++ .../elemental/elemental_from_proto.cc | 476 +++ .../elemental_from_proto_fuzz_test.cc | 81 + .../elemental/elemental_from_proto_test.cc | 456 +++ .../math_opt/elemental/elemental_matcher.cc | 50 + .../math_opt/elemental/elemental_matcher.h | 53 + .../elemental/elemental_matcher_test.cc | 48 + ortools/math_opt/elemental/elemental_test.cc | 1201 +++++++ .../math_opt/elemental/elemental_to_string.cc | 168 + .../elemental/elemental_to_string_test.cc | 154 + .../elemental_update_from_proto_test.cc | 763 ++++ ortools/math_opt/elemental/elements.cc | 35 + ortools/math_opt/elemental/elements.h | 272 ++ ortools/math_opt/elemental/elements_test.cc | 88 + ortools/math_opt/elemental/python/BUILD.bazel | 84 + .../math_opt/elemental/python/CMakeLists.txt | 46 + .../elemental/python/cpp_elemental.pyi | 5 + .../math_opt/elemental/python/elemental.cc | 727 ++++ .../elemental/python/elemental_test.py | 543 +++ ortools/math_opt/elemental/python/enums.py.in | 79 + .../math_opt/elemental/python/enums_test.py | 30 + ortools/math_opt/elemental/safe_attr_ops.h | 73 + .../math_opt/elemental/safe_attr_ops_test.cc | 105 + ortools/math_opt/elemental/symmetry.h | 87 + ortools/math_opt/elemental/testing.h | 48 + .../math_opt/elemental/thread_safe_id_map.h | 277 ++ .../elemental/thread_safe_id_map_test.cc | 125 + ortools/math_opt/io/BUILD.bazel | 2 + ortools/math_opt/io/lp/BUILD.bazel | 2 + ortools/math_opt/io/lp/lp_model.h | 1 + ortools/math_opt/io/proto_converter.cc | 4 +- ortools/math_opt/io/python/CMakeLists.txt | 46 + ortools/math_opt/labs/BUILD.bazel | 2 + ortools/math_opt/parameters.proto | 5 +- ortools/math_opt/python/BUILD.bazel | 115 +- ortools/math_opt/python/CMakeLists.txt | 1 - .../math_opt/python/bounded_expressions.py | 182 + .../python/bounded_expressions_test.py | 83 + ortools/math_opt/python/callback.py | 42 +- ortools/math_opt/python/callback_test.py | 3 +- .../compute_infeasible_subsystem_result.py | 12 +- ...ompute_infeasible_subsystem_result_test.py | 2 - ortools/math_opt/python/elemental/BUILD.bazel | 31 + .../math_opt/python/elemental/elemental.py | 398 +++ ortools/math_opt/python/expressions.py | 32 +- ortools/math_opt/python/expressions_test.py | 29 +- ortools/math_opt/python/from_model.py | 37 + .../math_opt/python/indicator_constraints.py | 146 + .../python/indicator_constraints_test.py | 125 + .../python/ipc/proto_converter_test.py | 18 +- .../math_opt/python/ipc/remote_http_solve.py | 8 +- .../python/ipc/remote_http_solve_test.py | 6 +- ortools/math_opt/python/linear_constraints.py | 155 + .../math_opt/python/linear_expression_test.py | 1073 +++--- ortools/math_opt/python/mathopt.py | 82 +- ortools/math_opt/python/mathopt_test.py | 18 +- .../math_opt/python/message_callback_test.py | 2 - ortools/math_opt/python/model.py | 2713 ++++---------- ortools/math_opt/python/model_element_test.py | 270 ++ .../math_opt/python/model_objective_test.py | 336 ++ ortools/math_opt/python/model_parameters.py | 119 +- .../math_opt/python/model_parameters_test.py | 79 + .../python/model_quadratic_constraint_test.py | 217 ++ ortools/math_opt/python/model_test.py | 724 +--- ortools/math_opt/python/normalize_test.py | 30 +- .../math_opt/python/normalized_inequality.py | 287 ++ .../python/normalized_inequality_test.py | 293 ++ ortools/math_opt/python/objectives.py | 545 +++ ortools/math_opt/python/objectives_test.py | 544 +++ .../math_opt/python/quadratic_constraints.py | 177 + .../python/quadratic_constraints_test.py | 95 + ortools/math_opt/python/result.py | 196 +- ortools/math_opt/python/result_test.py | 81 + ortools/math_opt/python/solution.py | 121 +- ortools/math_opt/python/solution_test.py | 200 ++ ortools/math_opt/python/solve.py | 4 +- ortools/math_opt/python/solve_gurobi_test.py | 83 + ortools/math_opt/python/solve_test.py | 28 +- ortools/math_opt/python/sparse_containers.py | 53 +- .../math_opt/python/sparse_containers_test.py | 118 +- ortools/math_opt/python/statistics_test.py | 2 - ortools/math_opt/python/testing/BUILD.bazel | 26 +- .../math_opt/python/testing/compare_proto.py | 41 +- .../python/testing/compare_proto_test.py | 61 + .../python/testing/proto_matcher_test.py | 66 + ortools/math_opt/python/variables.py | 1418 ++++++++ ortools/math_opt/rpc.proto | 1 - ortools/math_opt/samples/cpp/BUILD.bazel | 2 + ortools/math_opt/samples/cpp/cocktail_hour.cc | 1 + .../samples/cpp/time_indexed_scheduling.cc | 2 +- ortools/math_opt/samples/python/BUILD.bazel | 35 +- .../math_opt/samples/python/CMakeLists.txt | 3 + .../samples/python/facility_location.py | 300 ++ .../samples/python/hierarchical_objectives.py | 71 + .../samples/python/smallest_circle.py | 104 + ortools/math_opt/solver_tests/BUILD.bazel | 3 + .../math_opt/solver_tests/base_solver_test.cc | 4 +- .../math_opt/solver_tests/callback_tests.cc | 3 - .../math_opt/solver_tests/generic_tests.cc | 1 + .../solver_tests/ip_parameter_tests.cc | 10 - .../lp_model_solve_parameters_tests.cc | 2 +- .../solver_tests/lp_parameter_tests.cc | 2 +- ortools/math_opt/solver_tests/status_tests.cc | 3 - .../math_opt/solver_tests/testdata/beavma.mps | 3170 ++++++++--------- ortools/math_opt/solvers/BUILD.bazel | 7 + ortools/math_opt/solvers/cp_sat_solver.cc | 2 +- ortools/math_opt/solvers/glop_solver.cc | 2 +- ortools/math_opt/solvers/glpk/BUILD.bazel | 3 + ortools/math_opt/solvers/glpk/rays.cc | 2 +- ortools/math_opt/solvers/glpk_solver.cc | 2 +- ortools/math_opt/solvers/glpk_solver.h | 2 +- ortools/math_opt/solvers/gscip/BUILD.bazel | 2 + ortools/math_opt/solvers/gscip_solver.cc | 2 +- ortools/math_opt/solvers/gurobi/BUILD.bazel | 2 + ortools/math_opt/solvers/gurobi/g_gurobi.cc | 1 + ortools/math_opt/solvers/gurobi_solver.cc | 4 +- ortools/math_opt/solvers/highs_solver.h | 3 + ortools/math_opt/solvers/pdlp_solver.cc | 2 +- ortools/math_opt/solvers/xpress/BUILD.bazel | 1 + ortools/math_opt/solvers/xpress/g_xpress.cc | 7 +- ortools/math_opt/solvers/xpress/g_xpress.h | 5 +- ortools/math_opt/storage/BUILD.bazel | 84 +- .../storage/atomic_constraint_storage.h | 3 +- .../math_opt/storage/atomic_constraints_v2.h | 146 + .../storage/linear_constraint_storage.cc | 3 +- ortools/math_opt/storage/model_storage.cc | 7 +- ortools/math_opt/storage/model_storage.h | 32 +- .../math_opt/storage/model_storage_item.cc | 37 + ortools/math_opt/storage/model_storage_item.h | 213 ++ .../math_opt/storage/model_storage_types.h | 7 +- ortools/math_opt/storage/model_storage_v2.cc | 160 + ortools/math_opt/storage/model_storage_v2.h | 1139 ++++++ ortools/math_opt/storage/objective_storage.cc | 3 +- ortools/math_opt/testing/BUILD.bazel | 2 + ortools/math_opt/tools/BUILD.bazel | 3 + ortools/math_opt/tools/file_format_flags.cc | 2 +- ortools/math_opt/tools/mathopt_convert.cc | 2 +- ortools/math_opt/validators/BUILD.bazel | 2 + .../validators/termination_validator.cc | 2 +- ortools/python/setup.py.in | 13 +- 245 files changed, 28953 insertions(+), 5680 deletions(-) create mode 100644 ortools/math_opt/README.md create mode 100644 ortools/math_opt/elemental/BUILD.bazel create mode 100644 ortools/math_opt/elemental/CMakeLists.txt create mode 100644 ortools/math_opt/elemental/README.md create mode 100644 ortools/math_opt/elemental/arrays.h create mode 100644 ortools/math_opt/elemental/arrays_test.cc create mode 100644 ortools/math_opt/elemental/attr_diff.h create mode 100644 ortools/math_opt/elemental/attr_diff_test.cc create mode 100644 ortools/math_opt/elemental/attr_key.h create mode 100644 ortools/math_opt/elemental/attr_key_test.cc create mode 100644 ortools/math_opt/elemental/attr_storage.h create mode 100644 ortools/math_opt/elemental/attr_storage_test.cc create mode 100644 ortools/math_opt/elemental/attributes.h create mode 100644 ortools/math_opt/elemental/attributes_test.cc create mode 100644 ortools/math_opt/elemental/codegen/BUILD.bazel create mode 100644 ortools/math_opt/elemental/codegen/codegen.cc create mode 100644 ortools/math_opt/elemental/codegen/gen.cc create mode 100644 ortools/math_opt/elemental/codegen/gen.h create mode 100644 ortools/math_opt/elemental/codegen/gen_c.cc create mode 100644 ortools/math_opt/elemental/codegen/gen_c.h create mode 100644 ortools/math_opt/elemental/codegen/gen_c_test.cc create mode 100644 ortools/math_opt/elemental/codegen/gen_python.cc create mode 100644 ortools/math_opt/elemental/codegen/gen_python.h create mode 100644 ortools/math_opt/elemental/codegen/gen_python_test.cc create mode 100644 ortools/math_opt/elemental/codegen/gen_test.cc create mode 100644 ortools/math_opt/elemental/codegen/testing.h create mode 100644 ortools/math_opt/elemental/derived_data.h create mode 100644 ortools/math_opt/elemental/derived_data_test.cc create mode 100644 ortools/math_opt/elemental/diff.cc create mode 100644 ortools/math_opt/elemental/diff.h create mode 100644 ortools/math_opt/elemental/diff_test.cc create mode 100644 ortools/math_opt/elemental/element_diff.h create mode 100644 ortools/math_opt/elemental/element_diff_test.cc create mode 100644 ortools/math_opt/elemental/element_ref_tracker.h create mode 100644 ortools/math_opt/elemental/element_ref_tracker_test.cc create mode 100644 ortools/math_opt/elemental/element_storage.cc create mode 100644 ortools/math_opt/elemental/element_storage.h create mode 100644 ortools/math_opt/elemental/element_storage_test.cc create mode 100644 ortools/math_opt/elemental/elemental.cc create mode 100644 ortools/math_opt/elemental/elemental.h create mode 100644 ortools/math_opt/elemental/elemental_differencer.cc create mode 100644 ortools/math_opt/elemental/elemental_differencer.h create mode 100644 ortools/math_opt/elemental/elemental_differencer_test.cc create mode 100644 ortools/math_opt/elemental/elemental_export_model.cc create mode 100644 ortools/math_opt/elemental/elemental_export_model_test.cc create mode 100644 ortools/math_opt/elemental/elemental_export_model_update_test.cc create mode 100644 ortools/math_opt/elemental/elemental_from_proto.cc create mode 100644 ortools/math_opt/elemental/elemental_from_proto_fuzz_test.cc create mode 100644 ortools/math_opt/elemental/elemental_from_proto_test.cc create mode 100644 ortools/math_opt/elemental/elemental_matcher.cc create mode 100644 ortools/math_opt/elemental/elemental_matcher.h create mode 100644 ortools/math_opt/elemental/elemental_matcher_test.cc create mode 100644 ortools/math_opt/elemental/elemental_test.cc create mode 100644 ortools/math_opt/elemental/elemental_to_string.cc create mode 100644 ortools/math_opt/elemental/elemental_to_string_test.cc create mode 100644 ortools/math_opt/elemental/elemental_update_from_proto_test.cc create mode 100644 ortools/math_opt/elemental/elements.cc create mode 100644 ortools/math_opt/elemental/elements.h create mode 100644 ortools/math_opt/elemental/elements_test.cc create mode 100644 ortools/math_opt/elemental/python/BUILD.bazel create mode 100644 ortools/math_opt/elemental/python/CMakeLists.txt create mode 100644 ortools/math_opt/elemental/python/cpp_elemental.pyi create mode 100644 ortools/math_opt/elemental/python/elemental.cc create mode 100644 ortools/math_opt/elemental/python/elemental_test.py create mode 100755 ortools/math_opt/elemental/python/enums.py.in create mode 100644 ortools/math_opt/elemental/python/enums_test.py create mode 100644 ortools/math_opt/elemental/safe_attr_ops.h create mode 100644 ortools/math_opt/elemental/safe_attr_ops_test.cc create mode 100644 ortools/math_opt/elemental/symmetry.h create mode 100644 ortools/math_opt/elemental/testing.h create mode 100644 ortools/math_opt/elemental/thread_safe_id_map.h create mode 100644 ortools/math_opt/elemental/thread_safe_id_map_test.cc create mode 100644 ortools/math_opt/io/python/CMakeLists.txt create mode 100644 ortools/math_opt/python/bounded_expressions.py create mode 100644 ortools/math_opt/python/bounded_expressions_test.py create mode 100644 ortools/math_opt/python/elemental/BUILD.bazel create mode 100644 ortools/math_opt/python/elemental/elemental.py create mode 100644 ortools/math_opt/python/from_model.py create mode 100644 ortools/math_opt/python/indicator_constraints.py create mode 100644 ortools/math_opt/python/indicator_constraints_test.py create mode 100644 ortools/math_opt/python/linear_constraints.py create mode 100644 ortools/math_opt/python/model_element_test.py create mode 100644 ortools/math_opt/python/model_objective_test.py create mode 100644 ortools/math_opt/python/model_quadratic_constraint_test.py create mode 100644 ortools/math_opt/python/normalized_inequality.py create mode 100644 ortools/math_opt/python/normalized_inequality_test.py create mode 100644 ortools/math_opt/python/objectives.py create mode 100644 ortools/math_opt/python/objectives_test.py create mode 100644 ortools/math_opt/python/quadratic_constraints.py create mode 100644 ortools/math_opt/python/quadratic_constraints_test.py create mode 100644 ortools/math_opt/python/testing/compare_proto_test.py create mode 100644 ortools/math_opt/python/testing/proto_matcher_test.py create mode 100644 ortools/math_opt/python/variables.py create mode 100644 ortools/math_opt/samples/python/facility_location.py create mode 100644 ortools/math_opt/samples/python/hierarchical_objectives.py create mode 100644 ortools/math_opt/samples/python/smallest_circle.py create mode 100644 ortools/math_opt/storage/atomic_constraints_v2.h create mode 100644 ortools/math_opt/storage/model_storage_item.cc create mode 100644 ortools/math_opt/storage/model_storage_item.h create mode 100644 ortools/math_opt/storage/model_storage_v2.cc create mode 100644 ortools/math_opt/storage/model_storage_v2.h diff --git a/cmake/python.cmake b/cmake/python.cmake index 5b4ecc2461..9d37cad5c1 100644 --- a/cmake/python.cmake +++ b/cmake/python.cmake @@ -298,6 +298,8 @@ endforeach() if(BUILD_MATH_OPT) add_subdirectory(ortools/math_opt/core/python) + add_subdirectory(ortools/math_opt/elemental/python) + add_subdirectory(ortools/math_opt/io/python) add_subdirectory(ortools/math_opt/python) endif() @@ -331,7 +333,12 @@ if(BUILD_MATH_OPT) file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/__init__.py CONTENT "") file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/core/__init__.py CONTENT "") file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/core/python/__init__.py CONTENT "") + file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/elemental/__init__.py CONTENT "") + file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/elemental/python/__init__.py CONTENT "") + file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/io/__init__.py CONTENT "") + file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/io/python/__init__.py CONTENT "") file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/__init__.py CONTENT "") + file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/elemental/__init__.py CONTENT "") file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/ipc/__init__.py CONTENT "") file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/testing/__init__.py CONTENT "") file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/solvers/__init__.py CONTENT "") @@ -366,27 +373,42 @@ file(COPY ortools/linear_solver/python/model_builder_numbers.py DESTINATION ${PYTHON_PROJECT_DIR}/linear_solver/python) if(BUILD_MATH_OPT) + configure_file( + ortools/math_opt/elemental/python/enums.py.in + ${PYTHON_PROJECT_DIR}/math_opt/elemental/python/enums.py + COPYONLY) file(COPY + ortools/math_opt/python/bounded_expressions.py ortools/math_opt/python/callback.py ortools/math_opt/python/compute_infeasible_subsystem_result.py ortools/math_opt/python/errors.py ortools/math_opt/python/expressions.py + ortools/math_opt/python/from_model.py ortools/math_opt/python/hash_model_storage.py + ortools/math_opt/python/indicator_constraints.py ortools/math_opt/python/init_arguments.py + ortools/math_opt/python/linear_constraints.py ortools/math_opt/python/mathopt.py ortools/math_opt/python/message_callback.py ortools/math_opt/python/model.py ortools/math_opt/python/model_parameters.py ortools/math_opt/python/model_storage.py + ortools/math_opt/python/normalized_inequality.py ortools/math_opt/python/normalize.py + ortools/math_opt/python/objectives.py ortools/math_opt/python/parameters.py + ortools/math_opt/python/quadratic_constraints.py ortools/math_opt/python/result.py ortools/math_opt/python/solution.py ortools/math_opt/python/solve.py ortools/math_opt/python/solver_resources.py ortools/math_opt/python/sparse_containers.py ortools/math_opt/python/statistics.py + ortools/math_opt/python/variables.py DESTINATION ${PYTHON_PROJECT_DIR}/math_opt/python) + file(COPY + ortools/math_opt/python/elemental/elemental.py + DESTINATION ${PYTHON_PROJECT_DIR}/math_opt/python/elemental) file(COPY ortools/math_opt/python/ipc/proto_converter.py ortools/math_opt/python/ipc/remote_http_solve.py @@ -663,7 +685,13 @@ add_custom_command( $ ${PYTHON_PROJECT}/linear_solver/python COMMAND ${CMAKE_COMMAND} -E $,copy,true> - $ ${PYTHON_PROJECT}/math_opt/core/python + $ ${PYTHON_PROJECT}/math_opt/core/python + COMMAND ${CMAKE_COMMAND} -E + $,copy,true> + $ ${PYTHON_PROJECT}/math_opt/elemental/python + COMMAND ${CMAKE_COMMAND} -E + $,copy,true> + $ ${PYTHON_PROJECT}/math_opt/io/python COMMAND ${CMAKE_COMMAND} -E $,copy,true> $ ${PYTHON_PROJECT}/../pybind11_abseil @@ -697,7 +725,9 @@ add_custom_command( routing_pybind11 pywraplp model_builder_helper_pybind11 - math_opt_pybind11 + $<$:math_opt_core_pybind11> + $<$:math_opt_elemental_pybind11> + $<$:math_opt_io_pybind11> $ cp_model_helper_pybind11 rcpsp_pybind11 @@ -760,6 +790,10 @@ search_python_module( search_python_module( NAME wheel PACKAGE wheel) +search_python_module( + NAME typing_extensions + PACKAGE typing-extensions + NO_VERSION) add_custom_command( OUTPUT python/dist_timestamp diff --git a/ortools/math_opt/CMakeLists.txt b/ortools/math_opt/CMakeLists.txt index dffc756574..58e73959d9 100644 --- a/ortools/math_opt/CMakeLists.txt +++ b/ortools/math_opt/CMakeLists.txt @@ -18,6 +18,7 @@ endif() add_subdirectory(core) add_subdirectory(constraints) add_subdirectory(cpp) +add_subdirectory(elemental) add_subdirectory(io) add_subdirectory(labs) add_subdirectory(solver_tests) @@ -31,6 +32,7 @@ target_sources(${NAME} PUBLIC $ $ $ + $ $ $ $ diff --git a/ortools/math_opt/README.md b/ortools/math_opt/README.md new file mode 100644 index 0000000000..c5b948c4f9 --- /dev/null +++ b/ortools/math_opt/README.md @@ -0,0 +1,18 @@ +# math_opt + +The code in this directory provides a generic way of accessing mathematical +optimization solvers (sometimes called mathematical programming solvers), such +as GLOP, CP-SAT, SCIP and Gurobi. In particular, a single API is provided to +make these solvers largely interoperable. + +New code should prefer MathOpt to `MPSolver`, as defined in +[linear_solver.h](../linear_solver/linear_solver.h) +when possible. + +MathOpt has client libraries in C++, Python, and Java that most users should use +to build and solve their models. A proto API is also provided, but this is not +recommended for most users. + +See +[parameters.proto](../math_opt/parameters.proto?q=SolverTypeProto) +for the list of supported solvers. diff --git a/ortools/math_opt/callback.proto b/ortools/math_opt/callback.proto index c0134c130e..e604b101e5 100644 --- a/ortools/math_opt/callback.proto +++ b/ortools/math_opt/callback.proto @@ -158,7 +158,17 @@ message CallbackResultProto { bool is_lazy = 4; } - // Ends the solve early. + // When true it tells the solver to interrupt the solve as soon as possible. + // + // It can be set from any event. This is equivalent to using a + // SolveInterrupter and triggering it from the callback. + // + // Some solvers don't support interruption, in that case this is simply + // ignored and the solve terminates as usual. On top of that solvers may not + // immediately stop the solve. Thus the user should expect the callback to + // still be called after they set `terminate` to true in a previous + // call. Returning with `terminate` false after having previously returned + // true won't cancel the interruption. bool terminate = 1; // TODO(b/172214608): SCIP allows to reject a feasible solution without diff --git a/ortools/math_opt/constraints/indicator/BUILD.bazel b/ortools/math_opt/constraints/indicator/BUILD.bazel index 4eaad19944..9b3c1cfc6a 100644 --- a/ortools/math_opt/constraints/indicator/BUILD.bazel +++ b/ortools/math_opt/constraints/indicator/BUILD.bazel @@ -18,10 +18,11 @@ cc_library( srcs = ["indicator_constraint.cc"], hdrs = ["indicator_constraint.h"], deps = [ - "//ortools/base:intops", "//ortools/math_opt/constraints/util:model_util", "//ortools/math_opt/cpp:variable_and_expressions", + "//ortools/math_opt/elemental:elements", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "@abseil-cpp//absl/strings", ], ) @@ -51,6 +52,7 @@ cc_library( "//ortools/math_opt:model_update_cc_proto", "//ortools/math_opt:sparse_containers_cc_proto", "//ortools/math_opt/core:sorted", + "//ortools/math_opt/elemental:elements", "//ortools/math_opt/storage:atomic_constraint_storage", "//ortools/math_opt/storage:sparse_coefficient_map", "@abseil-cpp//absl/container:flat_hash_set", diff --git a/ortools/math_opt/constraints/indicator/indicator_constraint.cc b/ortools/math_opt/constraints/indicator/indicator_constraint.cc index 04a6884632..12860a0ae7 100644 --- a/ortools/math_opt/constraints/indicator/indicator_constraint.cc +++ b/ortools/math_opt/constraints/indicator/indicator_constraint.cc @@ -18,8 +18,6 @@ #include #include -#include "absl/strings/string_view.h" -#include "ortools/base/strong_int.h" #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/storage/model_storage.h" @@ -27,22 +25,22 @@ namespace operations_research::math_opt { BoundedLinearExpression IndicatorConstraint::ImpliedConstraint() const { - const IndicatorConstraintData& data = storage()->constraint_data(id_); + const IndicatorConstraintData& data = storage()->constraint_data(typed_id()); // NOTE: The following makes a copy of `data.linear_terms`. This can be made // more efficient if the need arises. LinearExpression expr = ToLinearExpression( - *storage_, {.coeffs = data.linear_terms, .offset = 0.0}); + *storage(), {.coeffs = data.linear_terms, .offset = 0.0}); return data.lower_bound <= std::move(expr) <= data.upper_bound; } std::string IndicatorConstraint::ToString() const { - if (!storage()->has_constraint(id_)) { + if (!storage()->has_constraint(typed_id())) { return std::string(kDeletedConstraintDefaultDescription); } - const IndicatorConstraintData& data = storage()->constraint_data(id_); + const IndicatorConstraintData& data = storage()->constraint_data(typed_id()); std::stringstream str; if (data.indicator.has_value()) { - str << Variable(storage_, *data.indicator) + str << Variable(storage(), *data.indicator) << (data.activate_on_zero ? " = 0" : " = 1"); } else { str << "[unset indicator variable]"; diff --git a/ortools/math_opt/constraints/indicator/indicator_constraint.h b/ortools/math_opt/constraints/indicator/indicator_constraint.h index 8f75f1df7f..9dace4bd7c 100644 --- a/ortools/math_opt/constraints/indicator/indicator_constraint.h +++ b/ortools/math_opt/constraints/indicator/indicator_constraint.h @@ -16,38 +16,28 @@ #ifndef OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_ #define OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_ -#include #include -#include #include #include #include "absl/strings/string_view.h" -#include "ortools/base/strong_int.h" #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" +#include "ortools/math_opt/elemental/elements.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" namespace operations_research::math_opt { // A value type that references an indicator constraint from ModelStorage. // Usually this type is passed by copy. -// -// This type implements https://abseil.io/docs/cpp/guides/hash. -class IndicatorConstraint { +class IndicatorConstraint final + : public ModelStorageElement { public: - // The typed integer used for ids. - using IdType = IndicatorConstraintId; + using ModelStorageElement::ModelStorageElement; - inline IndicatorConstraint(const ModelStorage* storage, - IndicatorConstraintId id); - - inline int64_t id() const; - - inline IndicatorConstraintId typed_id() const; - inline const ModelStorage* storage() const; - - inline absl::string_view name() const; + absl::string_view name() const; // Returns nullopt if the indicator variable is unset (this is a valid state, // in which the constraint is functionally ignored). @@ -65,91 +55,36 @@ class IndicatorConstraint { // Returns a detailed string description of the contents of the constraint // (not its name, use `<<` for that instead). std::string ToString() const; - - friend inline bool operator==(const IndicatorConstraint& lhs, - const IndicatorConstraint& rhs); - friend inline bool operator!=(const IndicatorConstraint& lhs, - const IndicatorConstraint& rhs); - template - friend H AbslHashValue(H h, const IndicatorConstraint& constraint); - friend std::ostream& operator<<(std::ostream& ostr, - const IndicatorConstraint& constraint); - - private: - const ModelStorage* storage_; - IndicatorConstraintId id_; }; -// Streams the name of the constraint, as registered upon constraint creation, -// or a short default if none was provided. -inline std::ostream& operator<<(std::ostream& ostr, - const IndicatorConstraint& constraint); - //////////////////////////////////////////////////////////////////////////////// // Inline function implementations //////////////////////////////////////////////////////////////////////////////// -int64_t IndicatorConstraint::id() const { return id_.value(); } - -IndicatorConstraintId IndicatorConstraint::typed_id() const { return id_; } - -const ModelStorage* IndicatorConstraint::storage() const { return storage_; } - -absl::string_view IndicatorConstraint::name() const { - if (storage_->has_constraint(id_)) { - return storage_->constraint_data(id_).name; +inline absl::string_view IndicatorConstraint::name() const { + if (storage()->has_constraint(typed_id())) { + return storage()->constraint_data(typed_id()).name; } return kDeletedConstraintDefaultDescription; } std::optional IndicatorConstraint::indicator_variable() const { const std::optional maybe_indicator = - storage_->constraint_data(id_).indicator; + storage()->constraint_data(typed_id()).indicator; if (!maybe_indicator.has_value()) { return std::nullopt; } - return Variable(storage_, *maybe_indicator); + return Variable(storage(), *maybe_indicator); } bool IndicatorConstraint::activate_on_zero() const { - return storage_->constraint_data(id_).activate_on_zero; + return storage()->constraint_data(typed_id()).activate_on_zero; } std::vector IndicatorConstraint::NonzeroVariables() const { - return AtomicConstraintNonzeroVariables(*storage_, id_); + return AtomicConstraintNonzeroVariables(*storage(), typed_id()); } -bool operator==(const IndicatorConstraint& lhs, - const IndicatorConstraint& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; -} - -bool operator!=(const IndicatorConstraint& lhs, - const IndicatorConstraint& rhs) { - return !(lhs == rhs); -} - -template -H AbslHashValue(H h, const IndicatorConstraint& constraint) { - return H::combine(std::move(h), constraint.id_.value(), constraint.storage_); -} - -std::ostream& operator<<(std::ostream& ostr, - const IndicatorConstraint& constraint) { - // TODO(b/170992529): handle quoting of invalid characters in the name. - const absl::string_view name = constraint.name(); - if (name.empty()) { - ostr << "__indic_con#" << constraint.id() << "__"; - } else { - ostr << name; - } - return ostr; -} - -IndicatorConstraint::IndicatorConstraint(const ModelStorage* const storage, - const IndicatorConstraintId id) - : storage_(storage), id_(id) {} - } // namespace operations_research::math_opt #endif // OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_ diff --git a/ortools/math_opt/constraints/indicator/storage.h b/ortools/math_opt/constraints/indicator/storage.h index f8c443b853..5e056ff634 100644 --- a/ortools/math_opt/constraints/indicator/storage.h +++ b/ortools/math_opt/constraints/indicator/storage.h @@ -19,6 +19,7 @@ #include #include +#include "ortools/math_opt/elemental/elements.h" #include "ortools/math_opt/model.pb.h" #include "ortools/math_opt/model_update.pb.h" #include "ortools/math_opt/storage/atomic_constraint_storage.h" @@ -34,6 +35,8 @@ struct IndicatorConstraintData { using IdType = IndicatorConstraintId; using ProtoType = IndicatorConstraintProto; using UpdatesProtoType = IndicatorConstraintUpdatesProto; + static constexpr ElementType kElementType = ElementType::kIndicatorConstraint; + static constexpr bool kSupportsElemental = true; // The `in_proto` must be in a valid state; see the inline comments on // `IndicatorConstraintProto` for details. @@ -55,6 +58,8 @@ struct IndicatorConstraintData { template <> struct AtomicConstraintTraits { using ConstraintData = IndicatorConstraintData; + static constexpr ElementType kElementType = ElementType::kIndicatorConstraint; + static constexpr bool kSupportsElemental = true; }; } // namespace operations_research::math_opt diff --git a/ortools/math_opt/constraints/quadratic/BUILD.bazel b/ortools/math_opt/constraints/quadratic/BUILD.bazel index 0a3e36eb39..1a395834e2 100644 --- a/ortools/math_opt/constraints/quadratic/BUILD.bazel +++ b/ortools/math_opt/constraints/quadratic/BUILD.bazel @@ -22,7 +22,9 @@ cc_library( "//ortools/math_opt/constraints/util:model_util", "//ortools/math_opt/cpp:key_types", "//ortools/math_opt/cpp:variable_and_expressions", + "//ortools/math_opt/elemental:elements", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:sparse_coefficient_map", "//ortools/math_opt/storage:sparse_matrix", "@abseil-cpp//absl/log:check", @@ -54,6 +56,7 @@ cc_library( "//ortools/math_opt:model_cc_proto", "//ortools/math_opt:model_update_cc_proto", "//ortools/math_opt:sparse_containers_cc_proto", + "//ortools/math_opt/elemental:elements", "//ortools/math_opt/storage:atomic_constraint_storage", "//ortools/math_opt/storage:model_storage_types", "//ortools/math_opt/storage:sparse_coefficient_map", diff --git a/ortools/math_opt/constraints/quadratic/quadratic_constraint.cc b/ortools/math_opt/constraints/quadratic/quadratic_constraint.cc index d5e4efa095..f75f42c53c 100644 --- a/ortools/math_opt/constraints/quadratic/quadratic_constraint.cc +++ b/ortools/math_opt/constraints/quadratic/quadratic_constraint.cc @@ -26,7 +26,7 @@ namespace operations_research::math_opt { BoundedQuadraticExpression QuadraticConstraint::AsBoundedQuadraticExpression() const { QuadraticExpression expression; - const QuadraticConstraintData& data = storage()->constraint_data(id_); + const QuadraticConstraintData& data = storage()->constraint_data(typed_id()); for (const auto [var, coeff] : data.linear_terms.terms()) { expression += coeff * Variable(storage(), var); } diff --git a/ortools/math_opt/constraints/quadratic/quadratic_constraint.h b/ortools/math_opt/constraints/quadratic/quadratic_constraint.h index f2bc13a829..f19a0c6b19 100644 --- a/ortools/math_opt/constraints/quadratic/quadratic_constraint.h +++ b/ortools/math_opt/constraints/quadratic/quadratic_constraint.h @@ -18,19 +18,18 @@ #ifndef OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_ #define OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_ -#include -#include #include #include #include #include "absl/log/check.h" #include "absl/strings/string_view.h" -#include "ortools/base/strong_int.h" #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/key_types.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" +#include "ortools/math_opt/elemental/elements.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" #include "ortools/math_opt/storage/sparse_coefficient_map.h" #include "ortools/math_opt/storage/sparse_matrix.h" @@ -38,20 +37,11 @@ namespace operations_research::math_opt { // A value type that references a quadratic constraint from ModelStorage. // Usually this type is passed by copy. -// -// This type implements https://abseil.io/docs/cpp/guides/hash. -class QuadraticConstraint { +class QuadraticConstraint final + : public ModelStorageElement { public: - // The typed integer used for ids. - using IdType = QuadraticConstraintId; - - inline QuadraticConstraint(const ModelStorage* storage, - QuadraticConstraintId id); - - inline int64_t id() const; - - inline QuadraticConstraintId typed_id() const; - inline const ModelStorage* storage() const; + using ModelStorageElement::ModelStorageElement; inline double lower_bound() const; inline double upper_bound() const; @@ -89,47 +79,23 @@ class QuadraticConstraint { // Returns a detailed string description of the contents of the constraint // (not its name, use `<<` for that instead). inline std::string ToString() const; - - friend inline bool operator==(const QuadraticConstraint& lhs, - const QuadraticConstraint& rhs); - friend inline bool operator!=(const QuadraticConstraint& lhs, - const QuadraticConstraint& rhs); - template - friend H AbslHashValue(H h, const QuadraticConstraint& quadratic_constraint); - friend std::ostream& operator<<( - std::ostream& ostr, const QuadraticConstraint& quadratic_constraint); - - private: - const ModelStorage* storage_; - QuadraticConstraintId id_; }; -// Streams the name of the constraint, as registered upon constraint creation, -// or a short default if none was provided. -inline std::ostream& operator<<(std::ostream& ostr, - const QuadraticConstraint& constraint); - //////////////////////////////////////////////////////////////////////////////// // Inline function implementations //////////////////////////////////////////////////////////////////////////////// -int64_t QuadraticConstraint::id() const { return id_.value(); } - -QuadraticConstraintId QuadraticConstraint::typed_id() const { return id_; } - -const ModelStorage* QuadraticConstraint::storage() const { return storage_; } - double QuadraticConstraint::lower_bound() const { - return storage_->constraint_data(id_).lower_bound; + return storage()->constraint_data(typed_id()).lower_bound; } double QuadraticConstraint::upper_bound() const { - return storage_->constraint_data(id_).upper_bound; + return storage()->constraint_data(typed_id()).upper_bound; } absl::string_view QuadraticConstraint::name() const { - if (storage_->has_constraint(id_)) { - return storage_->constraint_data(id_).name; + if (storage()->has_constraint(typed_id())) { + return storage()->constraint_data(typed_id()).name; } return kDeletedConstraintDefaultDescription; } @@ -145,27 +111,31 @@ bool QuadraticConstraint::is_quadratic_coefficient_nonzero( } double QuadraticConstraint::linear_coefficient(const Variable variable) const { - CHECK_EQ(variable.storage(), storage_) + CHECK_EQ(variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->constraint_data(id_).linear_terms.get(variable.typed_id()); + return storage() + ->constraint_data(typed_id()) + .linear_terms.get(variable.typed_id()); } double QuadraticConstraint::quadratic_coefficient( const Variable first_variable, const Variable second_variable) const { - CHECK_EQ(first_variable.storage(), storage_) + CHECK_EQ(first_variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - CHECK_EQ(second_variable.storage(), storage_) + CHECK_EQ(second_variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->constraint_data(id_).quadratic_terms.get( - first_variable.typed_id(), second_variable.typed_id()); + return storage() + ->constraint_data(typed_id()) + .quadratic_terms.get(first_variable.typed_id(), + second_variable.typed_id()); } std::vector QuadraticConstraint::NonzeroVariables() const { - return AtomicConstraintNonzeroVariables(*storage_, id_); + return AtomicConstraintNonzeroVariables(*storage(), typed_id()); } std::string QuadraticConstraint::ToString() const { - if (!storage()->has_constraint(id_)) { + if (!storage()->has_constraint(typed_id())) { return std::string(kDeletedConstraintDefaultDescription); } std::stringstream str; @@ -173,38 +143,6 @@ std::string QuadraticConstraint::ToString() const { return str.str(); } -bool operator==(const QuadraticConstraint& lhs, - const QuadraticConstraint& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; -} - -bool operator!=(const QuadraticConstraint& lhs, - const QuadraticConstraint& rhs) { - return !(lhs == rhs); -} - -template -H AbslHashValue(H h, const QuadraticConstraint& quadratic_constraint) { - return H::combine(std::move(h), quadratic_constraint.id_.value(), - quadratic_constraint.storage_); -} - -std::ostream& operator<<(std::ostream& ostr, - const QuadraticConstraint& constraint) { - // TODO(b/170992529): handle quoting of invalid characters in the name. - const absl::string_view name = constraint.name(); - if (name.empty()) { - ostr << "__quad_con#" << constraint.id() << "__"; - } else { - ostr << name; - } - return ostr; -} - -QuadraticConstraint::QuadraticConstraint(const ModelStorage* const storage, - const QuadraticConstraintId id) - : storage_(storage), id_(id) {} - } // namespace operations_research::math_opt #endif // OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_ diff --git a/ortools/math_opt/constraints/quadratic/storage.h b/ortools/math_opt/constraints/quadratic/storage.h index 15bfa082d7..162eff2628 100644 --- a/ortools/math_opt/constraints/quadratic/storage.h +++ b/ortools/math_opt/constraints/quadratic/storage.h @@ -18,6 +18,7 @@ #include #include +#include "ortools/math_opt/elemental/elements.h" #include "ortools/math_opt/model.pb.h" #include "ortools/math_opt/model_update.pb.h" #include "ortools/math_opt/storage/atomic_constraint_storage.h" @@ -36,6 +37,9 @@ struct QuadraticConstraintData { using ProtoType = QuadraticConstraintProto; using UpdatesProtoType = QuadraticConstraintUpdatesProto; + static constexpr ElementType kElementType = ElementType::kQuadraticConstraint; + static constexpr bool kSupportsElemental = true; + // The `in_proto` must be in a valid state; see the inline comments on // `QuadraticConstraintProto` for details. static QuadraticConstraintData FromProto(const ProtoType& in_proto); @@ -53,6 +57,8 @@ struct QuadraticConstraintData { template <> struct AtomicConstraintTraits { using ConstraintData = QuadraticConstraintData; + static constexpr ElementType kElementType = ElementType::kQuadraticConstraint; + static constexpr bool kSupportsElemental = true; }; } // namespace operations_research::math_opt diff --git a/ortools/math_opt/constraints/second_order_cone/BUILD.bazel b/ortools/math_opt/constraints/second_order_cone/BUILD.bazel index 59b63e6094..7f564322ac 100644 --- a/ortools/math_opt/constraints/second_order_cone/BUILD.bazel +++ b/ortools/math_opt/constraints/second_order_cone/BUILD.bazel @@ -24,6 +24,7 @@ cc_library( "//ortools/math_opt/cpp:variable_and_expressions", "//ortools/math_opt/storage:linear_expression_data", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:model_storage_types", "@abseil-cpp//absl/strings", ], diff --git a/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.cc b/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.cc index 04256b3198..5b35f54960 100644 --- a/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.cc +++ b/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.cc @@ -30,7 +30,7 @@ namespace operations_research::math_opt { LinearExpression SecondOrderConeConstraint::UpperBound() const { - return ToLinearExpression(*storage_, + return ToLinearExpression(*storage(), storage()->constraint_data(id_).upper_bound); } @@ -40,7 +40,7 @@ std::vector SecondOrderConeConstraint::ArgumentsToNorm() std::vector args; args.reserve(data.arguments_to_norm.size()); for (const LinearExpressionData& arg_data : data.arguments_to_norm) { - args.push_back(ToLinearExpression(*storage_, arg_data)); + args.push_back(ToLinearExpression(*storage(), arg_data)); } return args; } @@ -58,9 +58,9 @@ std::string SecondOrderConeConstraint::ToString() const { str << ", "; } leading_comma = true; - str << ToLinearExpression(*storage_, arg_data); + str << ToLinearExpression(*storage(), arg_data); } - str << "}||₂ ≤ " << ToLinearExpression(*storage_, data.upper_bound); + str << "}||₂ ≤ " << ToLinearExpression(*storage(), data.upper_bound); return str.str(); } diff --git a/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.h b/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.h index 22542943d5..edc9bb53e2 100644 --- a/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.h +++ b/ortools/math_opt/constraints/second_order_cone/second_order_cone_constraint.h @@ -17,7 +17,6 @@ #define OR_TOOLS_MATH_OPT_CONSTRAINTS_SECOND_ORDER_CONE_SECOND_ORDER_CONE_CONSTRAINT_H_ #include -#include #include #include #include @@ -28,6 +27,7 @@ #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" #include "ortools/math_opt/storage/model_storage_types.h" namespace operations_research::math_opt { @@ -36,18 +36,17 @@ namespace operations_research::math_opt { // ModelStorage. Usually this type is passed by copy. // // This type implements https://abseil.io/docs/cpp/guides/hash. -class SecondOrderConeConstraint { +class SecondOrderConeConstraint final : public ModelStorageItem { public: // The typed integer used for ids. using IdType = SecondOrderConeConstraintId; - inline SecondOrderConeConstraint(const ModelStorage* storage, + inline SecondOrderConeConstraint(ModelStorageCPtr storage, SecondOrderConeConstraintId id); inline int64_t id() const; inline SecondOrderConeConstraintId typed_id() const; - inline const ModelStorage* storage() const; inline absl::string_view name() const; @@ -77,7 +76,6 @@ class SecondOrderConeConstraint { const SecondOrderConeConstraint& constraint); private: - const ModelStorage* storage_; SecondOrderConeConstraintId id_; }; @@ -96,24 +94,20 @@ SecondOrderConeConstraintId SecondOrderConeConstraint::typed_id() const { return id_; } -const ModelStorage* SecondOrderConeConstraint::storage() const { - return storage_; -} - absl::string_view SecondOrderConeConstraint::name() const { - if (storage_->has_constraint(id_)) { - return storage_->constraint_data(id_).name; + if (storage()->has_constraint(id_)) { + return storage()->constraint_data(id_).name; } return kDeletedConstraintDefaultDescription; } std::vector SecondOrderConeConstraint::NonzeroVariables() const { - return AtomicConstraintNonzeroVariables(*storage_, id_); + return AtomicConstraintNonzeroVariables(*storage(), id_); } bool operator==(const SecondOrderConeConstraint& lhs, const SecondOrderConeConstraint& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; + return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage(); } bool operator!=(const SecondOrderConeConstraint& lhs, @@ -123,7 +117,7 @@ bool operator!=(const SecondOrderConeConstraint& lhs, template H AbslHashValue(H h, const SecondOrderConeConstraint& constraint) { - return H::combine(std::move(h), constraint.id_.value(), constraint.storage_); + return H::combine(std::move(h), constraint.id_.value(), constraint.storage()); } std::ostream& operator<<(std::ostream& ostr, @@ -139,8 +133,8 @@ std::ostream& operator<<(std::ostream& ostr, } SecondOrderConeConstraint::SecondOrderConeConstraint( - const ModelStorage* const storage, const SecondOrderConeConstraintId id) - : storage_(storage), id_(id) {} + const ModelStorageCPtr storage, const SecondOrderConeConstraintId id) + : ModelStorageItem(storage), id_(id) {} } // namespace operations_research::math_opt diff --git a/ortools/math_opt/constraints/second_order_cone/storage.h b/ortools/math_opt/constraints/second_order_cone/storage.h index 054490d6fb..25d739bb20 100644 --- a/ortools/math_opt/constraints/second_order_cone/storage.h +++ b/ortools/math_opt/constraints/second_order_cone/storage.h @@ -34,6 +34,7 @@ struct SecondOrderConeConstraintData { using IdType = SecondOrderConeConstraintId; using ProtoType = SecondOrderConeConstraintProto; using UpdatesProtoType = SecondOrderConeConstraintUpdatesProto; + static constexpr bool kSupportsElemental = false; // The `in_proto` must be in a valid state; see the inline comments on // `SecondOrderConeConstraintProto` for details. @@ -50,6 +51,7 @@ struct SecondOrderConeConstraintData { template <> struct AtomicConstraintTraits { using ConstraintData = SecondOrderConeConstraintData; + static constexpr bool kSupportsElemental = false; }; } // namespace operations_research::math_opt diff --git a/ortools/math_opt/constraints/sos/BUILD.bazel b/ortools/math_opt/constraints/sos/BUILD.bazel index ef3b8ef098..e5c55a0630 100644 --- a/ortools/math_opt/constraints/sos/BUILD.bazel +++ b/ortools/math_opt/constraints/sos/BUILD.bazel @@ -24,7 +24,9 @@ cc_library( "//ortools/math_opt/cpp:variable_and_expressions", "//ortools/math_opt/storage:linear_expression_data", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:sparse_coefficient_map", + "@abseil-cpp//absl/base:nullability", "@abseil-cpp//absl/strings", ], ) @@ -57,6 +59,7 @@ cc_library( "//ortools/math_opt/cpp:variable_and_expressions", "//ortools/math_opt/storage:linear_expression_data", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:sparse_coefficient_map", "@abseil-cpp//absl/strings", ], diff --git a/ortools/math_opt/constraints/sos/sos1_constraint.cc b/ortools/math_opt/constraints/sos/sos1_constraint.cc index 413955c78d..3662246f46 100644 --- a/ortools/math_opt/constraints/sos/sos1_constraint.cc +++ b/ortools/math_opt/constraints/sos/sos1_constraint.cc @@ -23,10 +23,10 @@ namespace operations_research::math_opt { LinearExpression Sos1Constraint::Expression(int index) const { const LinearExpressionData& storage_expr = - storage_->constraint_data(id_).expression(index); + storage()->constraint_data(id_).expression(index); LinearExpression out_expr = storage_expr.offset; for (const auto [var_id, coeff] : storage_expr.coeffs.terms()) { - out_expr += coeff * Variable(storage_, var_id); + out_expr += coeff * Variable(storage(), var_id); } return out_expr; } diff --git a/ortools/math_opt/constraints/sos/sos1_constraint.h b/ortools/math_opt/constraints/sos/sos1_constraint.h index f36c326615..c89c097862 100644 --- a/ortools/math_opt/constraints/sos/sos1_constraint.h +++ b/ortools/math_opt/constraints/sos/sos1_constraint.h @@ -21,12 +21,14 @@ #include #include +#include "absl/base/nullability.h" #include "absl/strings/string_view.h" #include "ortools/base/strong_int.h" #include "ortools/math_opt/constraints/sos/util.h" #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" namespace operations_research::math_opt { @@ -34,17 +36,16 @@ namespace operations_research::math_opt { // Usually this type is passed by copy. // // This type implements https://abseil.io/docs/cpp/guides/hash. -class Sos1Constraint { +class Sos1Constraint final : public ModelStorageItem { public: // The typed integer used for ids. using IdType = Sos1ConstraintId; - inline Sos1Constraint(const ModelStorage* storage, Sos1ConstraintId id); + inline Sos1Constraint(ModelStorageCPtr storage, Sos1ConstraintId id); inline int64_t id() const; inline Sos1ConstraintId typed_id() const; - inline const ModelStorage* storage() const; inline int64_t num_expressions() const; LinearExpression Expression(int index) const; @@ -69,7 +70,6 @@ class Sos1Constraint { const Sos1Constraint& constraint); private: - const ModelStorage* storage_; Sos1ConstraintId id_; }; @@ -86,33 +86,31 @@ int64_t Sos1Constraint::id() const { return id_.value(); } Sos1ConstraintId Sos1Constraint::typed_id() const { return id_; } -const ModelStorage* Sos1Constraint::storage() const { return storage_; } - int64_t Sos1Constraint::num_expressions() const { - return storage_->constraint_data(id_).num_expressions(); + return storage()->constraint_data(id_).num_expressions(); } bool Sos1Constraint::has_weights() const { - return storage_->constraint_data(id_).has_weights(); + return storage()->constraint_data(id_).has_weights(); } double Sos1Constraint::weight(int index) const { - return storage_->constraint_data(id_).weight(index); + return storage()->constraint_data(id_).weight(index); } absl::string_view Sos1Constraint::name() const { - if (storage_->has_constraint(id_)) { - return storage_->constraint_data(id_).name(); + if (storage()->has_constraint(id_)) { + return storage()->constraint_data(id_).name(); } return kDeletedConstraintDefaultDescription; } std::vector Sos1Constraint::NonzeroVariables() const { - return AtomicConstraintNonzeroVariables(*storage_, id_); + return AtomicConstraintNonzeroVariables(*storage(), id_); } bool operator==(const Sos1Constraint& lhs, const Sos1Constraint& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; + return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage(); } bool operator!=(const Sos1Constraint& lhs, const Sos1Constraint& rhs) { @@ -121,7 +119,7 @@ bool operator!=(const Sos1Constraint& lhs, const Sos1Constraint& rhs) { template H AbslHashValue(H h, const Sos1Constraint& constraint) { - return H::combine(std::move(h), constraint.id_.value(), constraint.storage_); + return H::combine(std::move(h), constraint.id_.value(), constraint.storage()); } std::ostream& operator<<(std::ostream& ostr, const Sos1Constraint& constraint) { @@ -136,15 +134,15 @@ std::ostream& operator<<(std::ostream& ostr, const Sos1Constraint& constraint) { } std::string Sos1Constraint::ToString() const { - if (storage_->has_constraint(id_)) { + if (storage()->has_constraint(id_)) { return internal::SosConstraintToString(*this, "SOS1"); } return std::string(kDeletedConstraintDefaultDescription); } -Sos1Constraint::Sos1Constraint(const ModelStorage* const storage, +Sos1Constraint::Sos1Constraint(const ModelStorageCPtr storage, const Sos1ConstraintId id) - : storage_(storage), id_(id) {} + : ModelStorageItem(storage), id_(id) {} } // namespace operations_research::math_opt diff --git a/ortools/math_opt/constraints/sos/sos2_constraint.cc b/ortools/math_opt/constraints/sos/sos2_constraint.cc index c0a84741c5..310de47fc4 100644 --- a/ortools/math_opt/constraints/sos/sos2_constraint.cc +++ b/ortools/math_opt/constraints/sos/sos2_constraint.cc @@ -23,10 +23,10 @@ namespace operations_research::math_opt { LinearExpression Sos2Constraint::Expression(int index) const { const LinearExpressionData& storage_expr = - storage_->constraint_data(id_).expression(index); + storage()->constraint_data(id_).expression(index); LinearExpression out_expr = storage_expr.offset; for (const auto [var_id, coeff] : storage_expr.coeffs.terms()) { - out_expr += coeff * Variable(storage_, var_id); + out_expr += coeff * Variable(storage(), var_id); } return out_expr; } diff --git a/ortools/math_opt/constraints/sos/sos2_constraint.h b/ortools/math_opt/constraints/sos/sos2_constraint.h index 40d98c7a12..115abac05b 100644 --- a/ortools/math_opt/constraints/sos/sos2_constraint.h +++ b/ortools/math_opt/constraints/sos/sos2_constraint.h @@ -27,6 +27,7 @@ #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" namespace operations_research::math_opt { @@ -34,17 +35,16 @@ namespace operations_research::math_opt { // Usually this type is passed by copy. // // This type implements https://abseil.io/docs/cpp/guides/hash. -class Sos2Constraint { +class Sos2Constraint final : public ModelStorageItem { public: // The typed integer used for ids. using IdType = Sos2ConstraintId; - inline Sos2Constraint(const ModelStorage* storage, Sos2ConstraintId id); + inline Sos2Constraint(ModelStorageCPtr storage, Sos2ConstraintId id); inline int64_t id() const; inline Sos2ConstraintId typed_id() const; - inline const ModelStorage* storage() const; inline int64_t num_expressions() const; LinearExpression Expression(int index) const; @@ -70,7 +70,6 @@ class Sos2Constraint { const Sos2Constraint& constraint); private: - const ModelStorage* storage_; Sos2ConstraintId id_; }; @@ -87,33 +86,31 @@ int64_t Sos2Constraint::id() const { return id_.value(); } Sos2ConstraintId Sos2Constraint::typed_id() const { return id_; } -const ModelStorage* Sos2Constraint::storage() const { return storage_; } - int64_t Sos2Constraint::num_expressions() const { - return storage_->constraint_data(id_).num_expressions(); + return storage()->constraint_data(id_).num_expressions(); } bool Sos2Constraint::has_weights() const { - return storage_->constraint_data(id_).has_weights(); + return storage()->constraint_data(id_).has_weights(); } double Sos2Constraint::weight(int index) const { - return storage_->constraint_data(id_).weight(index); + return storage()->constraint_data(id_).weight(index); } absl::string_view Sos2Constraint::name() const { - if (storage_->has_constraint(id_)) { - return storage_->constraint_data(id_).name(); + if (storage()->has_constraint(id_)) { + return storage()->constraint_data(id_).name(); } return kDeletedConstraintDefaultDescription; } std::vector Sos2Constraint::NonzeroVariables() const { - return AtomicConstraintNonzeroVariables(*storage_, id_); + return AtomicConstraintNonzeroVariables(*storage(), id_); } bool operator==(const Sos2Constraint& lhs, const Sos2Constraint& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; + return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage(); } bool operator!=(const Sos2Constraint& lhs, const Sos2Constraint& rhs) { @@ -122,7 +119,7 @@ bool operator!=(const Sos2Constraint& lhs, const Sos2Constraint& rhs) { template H AbslHashValue(H h, const Sos2Constraint& constraint) { - return H::combine(std::move(h), constraint.id_.value(), constraint.storage_); + return H::combine(std::move(h), constraint.id_.value(), constraint.storage()); } std::ostream& operator<<(std::ostream& ostr, const Sos2Constraint& constraint) { @@ -137,15 +134,15 @@ std::ostream& operator<<(std::ostream& ostr, const Sos2Constraint& constraint) { } std::string Sos2Constraint::ToString() const { - if (storage_->has_constraint(id_)) { + if (storage()->has_constraint(id_)) { return internal::SosConstraintToString(*this, "SOS2"); } return std::string(kDeletedConstraintDefaultDescription); } -Sos2Constraint::Sos2Constraint(const ModelStorage* const storage, +Sos2Constraint::Sos2Constraint(const ModelStorageCPtr storage, const Sos2ConstraintId id) - : storage_(storage), id_(id) {} + : ModelStorageItem(storage), id_(id) {} } // namespace operations_research::math_opt diff --git a/ortools/math_opt/constraints/sos/storage.h b/ortools/math_opt/constraints/sos/storage.h index 1e21734ddf..f4e5b5e20a 100644 --- a/ortools/math_opt/constraints/sos/storage.h +++ b/ortools/math_opt/constraints/sos/storage.h @@ -43,6 +43,7 @@ class SosConstraintData { using IdType = ConstraintId; using ProtoType = SosConstraintProto; using UpdatesProtoType = SosConstraintUpdatesProto; + static constexpr bool kSupportsElemental = false; static_assert( std::disjunction_v, @@ -101,11 +102,13 @@ using Sos2ConstraintData = internal::SosConstraintData; template <> struct AtomicConstraintTraits { using ConstraintData = Sos1ConstraintData; + static constexpr bool kSupportsElemental = false; }; template <> struct AtomicConstraintTraits { using ConstraintData = Sos2ConstraintData; + static constexpr bool kSupportsElemental = false; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/ortools/math_opt/constraints/util/model_util.h b/ortools/math_opt/constraints/util/model_util.h index aaaec5909b..dd12306367 100644 --- a/ortools/math_opt/constraints/util/model_util.h +++ b/ortools/math_opt/constraints/util/model_util.h @@ -52,7 +52,7 @@ std::vector AtomicConstraintNonzeroVariables( } // Duck-types on `ConstraintType` having a typedef for the associated `IdType`, -// and having a `(const ModelStorage*, IdType)` constructor. +// and having a `(ModelStorageCPtr, IdType)` constructor. template std::vector AtomicConstraints(const ModelStorage& storage) { using IdType = typename ConstraintType::IdType; diff --git a/ortools/math_opt/core/invalid_indicators.h b/ortools/math_opt/core/invalid_indicators.h index 3539443635..38666e3699 100644 --- a/ortools/math_opt/core/invalid_indicators.h +++ b/ortools/math_opt/core/invalid_indicators.h @@ -14,6 +14,7 @@ #ifndef OR_TOOLS_MATH_OPT_CORE_INVALID_INDICATORS_H_ #define OR_TOOLS_MATH_OPT_CORE_INVALID_INDICATORS_H_ +#include #include #include diff --git a/ortools/math_opt/core/inverted_bounds.h b/ortools/math_opt/core/inverted_bounds.h index 48763613f1..678fb36047 100644 --- a/ortools/math_opt/core/inverted_bounds.h +++ b/ortools/math_opt/core/inverted_bounds.h @@ -14,6 +14,7 @@ #ifndef OR_TOOLS_MATH_OPT_CORE_INVERTED_BOUNDS_H_ #define OR_TOOLS_MATH_OPT_CORE_INVERTED_BOUNDS_H_ +#include #include #include diff --git a/ortools/math_opt/core/math_opt_proto_utils.cc b/ortools/math_opt/core/math_opt_proto_utils.cc index 2d59979ea6..ad93e55bc7 100644 --- a/ortools/math_opt/core/math_opt_proto_utils.cc +++ b/ortools/math_opt/core/math_opt_proto_utils.cc @@ -23,6 +23,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "ortools/base/logging.h" diff --git a/ortools/math_opt/core/python/CMakeLists.txt b/ortools/math_opt/core/python/CMakeLists.txt index dd2d82504c..90dd8e7670 100644 --- a/ortools/math_opt/core/python/CMakeLists.txt +++ b/ortools/math_opt/core/python/CMakeLists.txt @@ -11,29 +11,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -pybind11_add_module(math_opt_pybind11 MODULE solver.cc) -set_target_properties(math_opt_pybind11 PROPERTIES +pybind11_add_module(math_opt_core_pybind11 MODULE solver.cc) +set_target_properties(math_opt_core_pybind11 PROPERTIES LIBRARY_OUTPUT_NAME "solver") # note: macOS is APPLE and also UNIX ! if(APPLE) - set_target_properties(math_opt_pybind11 PROPERTIES + set_target_properties(math_opt_core_pybind11 PROPERTIES SUFFIX ".so" INSTALL_RPATH "@loader_path;@loader_path/../../../../${PYTHON_PROJECT}/.libs;@loader_path/../../../../pybind11_abseil") elseif(UNIX) - set_target_properties(math_opt_pybind11 PROPERTIES + set_target_properties(math_opt_core_pybind11 PROPERTIES INSTALL_RPATH "$ORIGIN:$ORIGIN/../../../../${PYTHON_PROJECT}/.libs:$ORIGIN/../../../../pybind11_abseil") endif() -target_link_libraries(math_opt_pybind11 PRIVATE +target_link_libraries(math_opt_core_pybind11 PRIVATE ${PROJECT_NAMESPACE}::ortools pybind11_abseil::absl_casters pybind11_abseil::status_casters pybind11_native_proto_caster protobuf::libprotobuf) -add_library(${PROJECT_NAMESPACE}::math_opt_pybind11 ALIAS math_opt_pybind11) +add_library(${PROJECT_NAMESPACE}::math_opt_core_pybind11 ALIAS math_opt_core_pybind11) if(BUILD_TESTING) file(GLOB PYTHON_SRCS "*_test.py") diff --git a/ortools/math_opt/core/python/solver_test.py b/ortools/math_opt/core/python/solver_test.py index 9894072d37..f74f98f0ef 100644 --- a/ortools/math_opt/core/python/solver_test.py +++ b/ortools/math_opt/core/python/solver_test.py @@ -17,7 +17,7 @@ import threading from typing import Callable, Optional, Sequence from absl.testing import absltest from absl.testing import parameterized -from pybind11_abseil.status import StatusNotOk +from pybind11_abseil import status from ortools.math_opt import callback_pb2 from ortools.math_opt import model_parameters_pb2 from ortools.math_opt import model_pb2 @@ -116,7 +116,7 @@ class PybindSolverTest(parameterized.TestCase): with self.assertRaisesRegex(RuntimeError, "id 7 not found"): _solve_model(model, use_solver_class=use_solver_class) else: - with self.assertRaisesRegex(StatusNotOk, "id 7 not found"): + with self.assertRaisesRegex(status.StatusNotOk, "id 7 not found"): _solve_model(model, use_solver_class=use_solver_class) @parameterized.named_parameters( diff --git a/ortools/math_opt/core/solver.cc b/ortools/math_opt/core/solver.cc index 9b22b168c7..382bb65823 100644 --- a/ortools/math_opt/core/solver.cc +++ b/ortools/math_opt/core/solver.cc @@ -49,7 +49,7 @@ namespace { // Returns an InternalError with the input status message if the input status is // not OK. -absl::Status ToInternalError(const absl::Status original) { +absl::Status ToInternalError(absl::Status original) { if (original.ok()) { return original; } @@ -201,7 +201,7 @@ Solver::ComputeInfeasibleSubsystem( RETURN_IF_ERROR(ValidateSolveParameters(arguments.parameters)) << "invalid parameters"; - ASSIGN_OR_RETURN(const ComputeInfeasibleSubsystemResultProto result, + ASSIGN_OR_RETURN(ComputeInfeasibleSubsystemResultProto result, underlying_solver_->ComputeInfeasibleSubsystem( arguments.parameters, arguments.message_callback, arguments.interrupter)); diff --git a/ortools/math_opt/core/solver_debug.h b/ortools/math_opt/core/solver_debug.h index aaf59048fe..4b3202a4bf 100644 --- a/ortools/math_opt/core/solver_debug.h +++ b/ortools/math_opt/core/solver_debug.h @@ -28,7 +28,8 @@ namespace internal { // This variable is intended to be used by MathOpt unit tests in other languages // to test the proper garbage collection. It should never be used in any other // context. -OR_DLL extern std::atomic debug_num_solver; +OR_DLL +extern std::atomic debug_num_solver; } // namespace internal } // namespace math_opt diff --git a/ortools/math_opt/cpp/BUILD.bazel b/ortools/math_opt/cpp/BUILD.bazel index 3a09813ec4..04799b7d48 100644 --- a/ortools/math_opt/cpp/BUILD.bazel +++ b/ortools/math_opt/cpp/BUILD.bazel @@ -94,7 +94,9 @@ cc_library( "//ortools/math_opt/storage:sparse_coefficient_map", "//ortools/math_opt/storage:sparse_matrix", "//ortools/util:fp_roundtrip_conv", + "@abseil-cpp//absl/base:nullability", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/log:die_if_null", "@abseil-cpp//absl/status", "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", @@ -113,6 +115,7 @@ cc_library( "//ortools/base:intops", "//ortools/base:map_util", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:model_storage_types", "//ortools/util:fp_roundtrip_conv", "@abseil-cpp//absl/base:core_headers", @@ -130,8 +133,8 @@ cc_library( deps = [ ":key_types", ":variable_and_expressions", - "//ortools/base:intops", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:model_storage_types", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/strings", @@ -144,10 +147,11 @@ cc_library( deps = [ ":key_types", ":variable_and_expressions", - "//ortools/base:intops", "//ortools/math_opt/constraints/util:model_util", "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "//ortools/math_opt/storage:model_storage_types", + "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/strings", ], @@ -259,6 +263,7 @@ cc_library( hdrs = ["key_types.h"], deps = [ "//ortools/math_opt/storage:model_storage", + "//ortools/math_opt/storage:model_storage_item", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/status", diff --git a/ortools/math_opt/cpp/basis_status.h b/ortools/math_opt/cpp/basis_status.h index 07b59ae336..ce0465f8fe 100644 --- a/ortools/math_opt/cpp/basis_status.h +++ b/ortools/math_opt/cpp/basis_status.h @@ -18,9 +18,7 @@ #define OR_TOOLS_MATH_OPT_CPP_BASIS_STATUS_H_ #include -#include -#include "absl/types/span.h" #include "ortools/math_opt/cpp/enums.h" // IWYU pragma: export #include "ortools/math_opt/solution.pb.h" diff --git a/ortools/math_opt/cpp/callback.cc b/ortools/math_opt/cpp/callback.cc index 62360b9311..3dcc2d2233 100644 --- a/ortools/math_opt/cpp/callback.cc +++ b/ortools/math_opt/cpp/callback.cc @@ -70,7 +70,7 @@ CallbackData::CallbackData(const CallbackEvent event, const absl::Duration runtime) : event(event), runtime(runtime) {} -CallbackData::CallbackData(const ModelStorage* storage, +CallbackData::CallbackData(const ModelStorageCPtr storage, const CallbackDataProto& proto) // iOS 11 does not support .value() hence we use operator* here and CHECK // below that we have a value. @@ -90,7 +90,7 @@ CallbackData::CallbackData(const ModelStorage* storage, } absl::Status CallbackRegistration::CheckModelStorage( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { RETURN_IF_ERROR(mip_node_filter.CheckModelStorage(expected_storage)) << "invalid mip_node_filter"; RETURN_IF_ERROR(mip_solution_filter.CheckModelStorage(expected_storage)) @@ -113,7 +113,7 @@ CallbackRegistrationProto CallbackRegistration::Proto() const { } absl::Status CallbackResult::CheckModelStorage( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { for (const GeneratedLinearConstraint& constraint : new_constraints) { RETURN_IF_ERROR( internal::CheckModelStorage(/*storage=*/constraint.storage(), diff --git a/ortools/math_opt/cpp/callback.h b/ortools/math_opt/cpp/callback.h index 686b1d0417..b0725a97b4 100644 --- a/ortools/math_opt/cpp/callback.h +++ b/ortools/math_opt/cpp/callback.h @@ -144,7 +144,7 @@ MATH_OPT_DEFINE_ENUM(CallbackEvent, CALLBACK_EVENT_UNSPECIFIED); struct CallbackRegistration { // Returns a failure if the referenced variables don't belong to the input // expected_storage (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // Returns the proto equivalent of this object. // @@ -189,7 +189,7 @@ struct CallbackData { // Users will typically not need this function. // Will CHECK fail if proto is not valid. - CallbackData(const ModelStorage* storage, const CallbackDataProto& proto); + CallbackData(ModelStorageCPtr storage, const CallbackDataProto& proto); // The current state of the underlying solver. CallbackEvent event; @@ -229,7 +229,7 @@ struct CallbackResult { BoundedLinearExpression linear_constraint; bool is_lazy = false; - const ModelStorage* storage() const { + NullableModelStorageCPtr storage() const { return linear_constraint.expression.storage(); } }; @@ -249,7 +249,7 @@ struct CallbackResult { // Returns a failure if the referenced variables don't belong to the input // expected_storage (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // Returns the proto equivalent of this object. // @@ -257,7 +257,17 @@ struct CallbackResult { // internal consistency of the referenced variables. CallbackResultProto Proto() const; - // Stop the solve process and return early. Can be called from any event. + // When true it tells the solver to interrupt the solve as soon as possible. + // + // It can be set from any event. This is equivalent to using a + // SolveInterrupter and triggering it from the callback. + // + // Some solvers don't support interruption, in that case this is simply + // ignored and the solve terminates as usual. On top of that solvers may not + // immediately stop the solve. Thus the user should expect the callback to + // still be called after they set `terminate` to true in a previous + // call. Returning with `terminate` false after having previously returned + // true won't cancel the interruption. bool terminate = false; // The user cuts and lazy constraints added. Prefer AddUserCut() and diff --git a/ortools/math_opt/cpp/compute_infeasible_subsystem_result.cc b/ortools/math_opt/cpp/compute_infeasible_subsystem_result.cc index 7490893adf..448fc4a566 100644 --- a/ortools/math_opt/cpp/compute_infeasible_subsystem_result.cc +++ b/ortools/math_opt/cpp/compute_infeasible_subsystem_result.cc @@ -76,7 +76,7 @@ template absl::Status BoundsMapProtoToCpp( const google::protobuf::Map& source, absl::flat_hash_map& target, - const ModelStorage* const model, + const ModelStorageCPtr model, bool (ModelStorage::* const contains_strong_id)(typename K::IdType id) const, const absl::string_view object_name) { @@ -95,7 +95,7 @@ absl::Status BoundsMapProtoToCpp( template absl::Status RepeatedIdsProtoToCpp( const google::protobuf::RepeatedField& source, - absl::flat_hash_set& target, const ModelStorage* const model, + absl::flat_hash_set& target, const ModelStorageCPtr model, bool (ModelStorage::* const contains_strong_id)(typename K::IdType id) const, const absl::string_view object_name) { @@ -134,7 +134,7 @@ google::protobuf::RepeatedField RepeatedIdsCppToProto( } // namespace absl::StatusOr ModelSubset::FromProto( - const ModelStorage* const model, const ModelSubsetProto& proto) { + const ModelStorageCPtr model, const ModelSubsetProto& proto) { ModelSubset model_subset; RETURN_IF_ERROR(BoundsMapProtoToCpp(proto.variable_bounds(), model_subset.variable_bounds, model, @@ -184,7 +184,7 @@ ModelSubsetProto ModelSubset::Proto() const { } absl::Status ModelSubset::CheckModelStorage( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { const auto validate_map_keys = [expected_storage](const auto& map, const absl::string_view name) -> absl::Status { @@ -348,7 +348,7 @@ std::ostream& operator<<(std::ostream& out, const ModelSubset& model_subset) { absl::StatusOr ComputeInfeasibleSubsystemResult::FromProto( - const ModelStorage* const model, + const ModelStorageCPtr model, const ComputeInfeasibleSubsystemResultProto& result_proto) { ComputeInfeasibleSubsystemResult result; const std::optional feasibility = @@ -383,7 +383,7 @@ ComputeInfeasibleSubsystemResultProto ComputeInfeasibleSubsystemResult::Proto() } absl::Status ComputeInfeasibleSubsystemResult::CheckModelStorage( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { return infeasible_subsystem.CheckModelStorage(expected_storage); } diff --git a/ortools/math_opt/cpp/compute_infeasible_subsystem_result.h b/ortools/math_opt/cpp/compute_infeasible_subsystem_result.h index c97c831b3a..f61dd2e238 100644 --- a/ortools/math_opt/cpp/compute_infeasible_subsystem_result.h +++ b/ortools/math_opt/cpp/compute_infeasible_subsystem_result.h @@ -64,7 +64,7 @@ struct ModelSubset { // // Returns an error when `model` does not contain a variable or constraint // associated with an index present in `proto`. - static absl::StatusOr FromProto(const ModelStorage* model, + static absl::StatusOr FromProto(ModelStorageCPtr model, const ModelSubsetProto& proto); // Returns the proto equivalent of this object. @@ -74,8 +74,8 @@ struct ModelSubset { ModelSubsetProto Proto() const; // Returns a failure if the `Variable` and Constraints contained in the fields - // do not belong to the input expected_storage (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + // do not belong to the input expected_storage. + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // True if this object corresponds to the empty subset. bool empty() const; @@ -105,7 +105,7 @@ struct ComputeInfeasibleSubsystemResult { // index present in `proto.infeasible_subsystem`. // * ValidateComputeInfeasibleSubsystemResultNoModel(result_proto) fails. static absl::StatusOr FromProto( - const ModelStorage* model, + ModelStorageCPtr model, const ComputeInfeasibleSubsystemResultProto& result_proto); // Returns the proto equivalent of this object. @@ -116,8 +116,8 @@ struct ComputeInfeasibleSubsystemResult { ComputeInfeasibleSubsystemResultProto Proto() const; // Returns a failure if this object contains references to a model other than - // `expected_storage` (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + // `expected_storage`. + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // The primal feasibility status of the model, as determined by the solver. FeasibilityStatus feasibility = FeasibilityStatus::kUndetermined; diff --git a/ortools/math_opt/cpp/key_types.h b/ortools/math_opt/cpp/key_types.h index e066bfff7b..5ceec1dd80 100644 --- a/ortools/math_opt/cpp/key_types.h +++ b/ortools/math_opt/cpp/key_types.h @@ -25,14 +25,16 @@ // // A key type K must match the following requirements: // - K::IdType is a value type used for indices. -// - K has a constructor K(const ModelStorage*, K::IdType). +// - K has a constructor K(ModelStorageCPtr, K::IdType). // - K is a value-semantic type. // - K has a function with signature `K::IdType K::typed_id() const`. -// - K has a function with signature `const ModelStorage* K::storage() const`. -// It must return a non-null pointer. +// - K has a function with signature `ModelStorageCPtr K::storage() const`. // - K::IdType is a valid key for absl::flat_hash_map or absl::flat_hash_set // (supports hash and ==). // - the is_key_type_v<> below should include them. +// TODO(b/396580721): Those requirements are those of `ModelStorageElement`. +// Once we've migrated most key types to `ModelStorageElement`, we should be +// able to simplify this code. #ifndef OR_TOOLS_MATH_OPT_CPP_KEY_TYPES_H_ #define OR_TOOLS_MATH_OPT_CPP_KEY_TYPES_H_ @@ -45,6 +47,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" namespace operations_research::math_opt { @@ -70,11 +73,9 @@ class Objective; // the values in the hash map are in the math_opt namespace. template constexpr inline bool is_key_type_v = - (std::is_same_v || std::is_same_v || - std::is_same_v || + (is_model_storage_element::value || std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || std::is_same_v); // Returns the keys of the map sorted by their (storage(), type_id()). @@ -162,12 +163,12 @@ inline constexpr absl::string_view kInputFromInvalidModelStorage = "the input does not belong to the same model"; // Returns a failure when the input pointer is not nullptr and points to a -// different model storage than expected_storage (which must not be nullptr). +// different model storage than expected_storage. // // Failure message is kInputFromInvalidModelStorage. -inline absl::Status CheckModelStorage( - const ModelStorage* const storage, - const ModelStorage* const expected_storage) { +inline absl::Status CheckModelStorage(const NullableModelStorageCPtr storage, + const ModelStorageCPtr expected_storage) { + // This is not allowed by the contract, but let's be safe. if (expected_storage == nullptr) { return absl::InternalError("expected_storage is nullptr"); } diff --git a/ortools/math_opt/cpp/linear_constraint.h b/ortools/math_opt/cpp/linear_constraint.h index 34837f0f6f..1d70c5f844 100644 --- a/ortools/math_opt/cpp/linear_constraint.h +++ b/ortools/math_opt/cpp/linear_constraint.h @@ -18,19 +18,18 @@ #ifndef OR_TOOLS_MATH_OPT_CPP_LINEAR_CONSTRAINT_H_ #define OR_TOOLS_MATH_OPT_CPP_LINEAR_CONSTRAINT_H_ -#include -#include #include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/strings/string_view.h" -#include "ortools/base/strong_int.h" #include "ortools/math_opt/constraints/util/model_util.h" #include "ortools/math_opt/cpp/key_types.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" #include "ortools/math_opt/storage/model_storage_types.h" namespace operations_research { @@ -38,19 +37,11 @@ namespace math_opt { // A value type that references a linear constraint from ModelStorage. Usually // this type is passed by copy. -// -// This type implements https://abseil.io/docs/cpp/guides/hash. -class LinearConstraint { +class LinearConstraint final + : public ModelStorageElement { public: - // The typed integer used for ids. - using IdType = LinearConstraintId; - - inline LinearConstraint(const ModelStorage* storage, LinearConstraintId id); - - inline int64_t id() const; - - inline LinearConstraintId typed_id() const; - inline const ModelStorage* storage() const; + using ModelStorageElement::ModelStorageElement; inline double lower_bound() const; inline double upper_bound() const; @@ -76,65 +67,42 @@ class LinearConstraint { // Returns a detailed string description of the contents of the constraint // (not its name, use `<<` for that instead). inline std::string ToString() const; - - friend inline bool operator==(const LinearConstraint& lhs, - const LinearConstraint& rhs); - friend inline bool operator!=(const LinearConstraint& lhs, - const LinearConstraint& rhs); - template - friend H AbslHashValue(H h, const LinearConstraint& linear_constraint); - friend std::ostream& operator<<(std::ostream& ostr, - const LinearConstraint& linear_constraint); - - private: - const ModelStorage* storage_; - LinearConstraintId id_; }; template using LinearConstraintMap = absl::flat_hash_map; -// Streams the name of the constraint, as registered upon constraint creation, -// or a short default if none was provided. -inline std::ostream& operator<<(std::ostream& ostr, - const LinearConstraint& linear_constraint); - //////////////////////////////////////////////////////////////////////////////// // Inline function implementations //////////////////////////////////////////////////////////////////////////////// -int64_t LinearConstraint::id() const { return id_.value(); } - -LinearConstraintId LinearConstraint::typed_id() const { return id_; } - -const ModelStorage* LinearConstraint::storage() const { return storage_; } - double LinearConstraint::lower_bound() const { - return storage_->linear_constraint_lower_bound(id_); + return storage()->linear_constraint_lower_bound(typed_id()); } double LinearConstraint::upper_bound() const { - return storage_->linear_constraint_upper_bound(id_); + return storage()->linear_constraint_upper_bound(typed_id()); } absl::string_view LinearConstraint::name() const { - if (storage()->has_linear_constraint(id_)) { - return storage_->linear_constraint_name(id_); + if (storage()->has_linear_constraint(typed_id())) { + return storage()->linear_constraint_name(typed_id()); } return kDeletedConstraintDefaultDescription; } bool LinearConstraint::is_coefficient_nonzero(const Variable variable) const { - CHECK_EQ(variable.storage(), storage_) + CHECK_EQ(variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->is_linear_constraint_coefficient_nonzero( - id_, variable.typed_id()); + return storage()->is_linear_constraint_coefficient_nonzero( + typed_id(), variable.typed_id()); } double LinearConstraint::coefficient(const Variable variable) const { - CHECK_EQ(variable.storage(), storage_) + CHECK_EQ(variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->linear_constraint_coefficient(id_, variable.typed_id()); + return storage()->linear_constraint_coefficient(typed_id(), + variable.typed_id()); } BoundedLinearExpression LinearConstraint::AsBoundedLinearExpression() const { @@ -150,7 +118,7 @@ BoundedLinearExpression LinearConstraint::AsBoundedLinearExpression() const { } std::string LinearConstraint::ToString() const { - if (!storage()->has_linear_constraint(id_)) { + if (!storage()->has_linear_constraint(typed_id())) { return std::string(kDeletedConstraintDefaultDescription); } std::stringstream str; @@ -158,36 +126,6 @@ std::string LinearConstraint::ToString() const { return str.str(); } -bool operator==(const LinearConstraint& lhs, const LinearConstraint& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; -} - -bool operator!=(const LinearConstraint& lhs, const LinearConstraint& rhs) { - return !(lhs == rhs); -} - -template -H AbslHashValue(H h, const LinearConstraint& linear_constraint) { - return H::combine(std::move(h), linear_constraint.id_.value(), - linear_constraint.storage_); -} - -std::ostream& operator<<(std::ostream& ostr, - const LinearConstraint& linear_constraint) { - // TODO(b/170992529): handle quoting of invalid characters in the name. - const absl::string_view name = linear_constraint.name(); - if (name.empty()) { - ostr << "__lin_con#" << linear_constraint.id() << "__"; - } else { - ostr << name; - } - return ostr; -} - -LinearConstraint::LinearConstraint(const ModelStorage* const storage, - const LinearConstraintId id) - : storage_(storage), id_(id) {} - } // namespace math_opt } // namespace operations_research diff --git a/ortools/math_opt/cpp/map_filter.h b/ortools/math_opt/cpp/map_filter.h index 0fb51eb4b9..dd2c622d44 100644 --- a/ortools/math_opt/cpp/map_filter.h +++ b/ortools/math_opt/cpp/map_filter.h @@ -100,7 +100,7 @@ struct MapFilter { // Returns a failure if the keys don't belong to the input expected_storage // (which must not be nullptr). inline absl::Status CheckModelStorage( - const ModelStorage* expected_storage) const; + ModelStorageCPtr expected_storage) const; // Returns the proto corresponding to this filter. // @@ -192,7 +192,7 @@ MapFilter MakeKeepKeysFilter(std::initializer_list keys) { template absl::Status MapFilter::CheckModelStorage( - const ModelStorage* expected_storage) const { + const ModelStorageCPtr expected_storage) const { if (!filtered_keys.has_value()) { return absl::OkStatus(); } diff --git a/ortools/math_opt/cpp/matchers.cc b/ortools/math_opt/cpp/matchers.cc index d049016e90..f426071801 100644 --- a/ortools/math_opt/cpp/matchers.cc +++ b/ortools/math_opt/cpp/matchers.cc @@ -25,6 +25,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "gtest/gtest.h" @@ -871,7 +872,7 @@ std::vector CompatibleReasons( } Matcher> CheckSolutions( - const std::vector& expected_solutions, + absl::Span expected_solutions, const SolveResultMatcherOptions& options) { if (options.first_solution_only && !expected_solutions.empty()) { return FirstElementIs( diff --git a/ortools/math_opt/cpp/message_callback.cc b/ortools/math_opt/cpp/message_callback.cc index 17255eaf09..992806acbf 100644 --- a/ortools/math_opt/cpp/message_callback.cc +++ b/ortools/math_opt/cpp/message_callback.cc @@ -37,7 +37,7 @@ class PrinterMessageCallbackImpl { const absl::string_view prefix) : output_stream_(output_stream), prefix_(prefix) {} - void Call(const std::vector& messages) { + void Call(absl::Span messages) { const absl::MutexLock lock(&mutex_); for (const std::string& message : messages) { output_stream_ << prefix_ << message << '\n'; @@ -56,7 +56,7 @@ void PushBack(absl::Span messages, sink->insert(sink->end(), messages.begin(), messages.end()); } -void PushBack(const std::vector& messages, +void PushBack(absl::Span messages, google::protobuf::RepeatedPtrField* const sink) { std::copy(messages.begin(), messages.end(), google::protobuf::RepeatedFieldBackInserter(sink)); @@ -68,7 +68,7 @@ class VectorLikeMessageCallbackImpl { explicit VectorLikeMessageCallbackImpl(Sink* const sink) : sink_(ABSL_DIE_IF_NULL(sink)) {} - void Call(const std::vector& messages) { + void Call(absl::Span messages) { const absl::MutexLock lock(&mutex_); PushBack(messages, sink_); } @@ -102,7 +102,7 @@ MessageCallback InfoLoggerMessageCallback(const absl::string_view prefix, MessageCallback VLoggerMessageCallback(int level, absl::string_view prefix, absl::SourceLocation loc) { - return [=](const std::vector& messages) { + return [=](absl::Span messages) { for (const std::string& message : messages) { VLOG(level).AtLocation(loc.file_name(), loc.line()) << prefix << message; } @@ -117,8 +117,7 @@ MessageCallback VectorMessageCallback(std::vector* sink) { const auto impl = std::make_shared>>( sink); - return - [=](const std::vector& messages) { impl->Call(messages); }; + return [=](absl::Span messages) { impl->Call(messages); }; } MessageCallback RepeatedPtrFieldMessageCallback( diff --git a/ortools/math_opt/cpp/model.cc b/ortools/math_opt/cpp/model.cc index 00420eaba8..12ea552d78 100644 --- a/ortools/math_opt/cpp/model.cc +++ b/ortools/math_opt/cpp/model.cc @@ -22,7 +22,9 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/check.h" +#include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -53,7 +55,7 @@ constexpr double kInf = std::numeric_limits::infinity(); absl::StatusOr> Model::FromModelProto( const ModelProto& model_proto) { - ASSIGN_OR_RETURN(std::unique_ptr storage, + ASSIGN_OR_RETURN(absl::Nonnull> storage, ModelStorage::FromModelProto(model_proto)); return std::make_unique(std::move(storage)); } @@ -61,10 +63,10 @@ absl::StatusOr> Model::FromModelProto( Model::Model(const absl::string_view name) : storage_(std::make_shared(name)) {} -Model::Model(std::unique_ptr storage) - : storage_(std::move(storage)) {} +Model::Model(absl::Nonnull> storage) + : storage_(ABSL_DIE_IF_NULL(std::move(storage))) {} -std::unique_ptr Model::Clone( +absl::Nonnull> Model::Clone( const std::optional new_name) const { return std::make_unique(storage_->Clone(new_name)); } diff --git a/ortools/math_opt/cpp/model.h b/ortools/math_opt/cpp/model.h index 37cca8f377..bb9939f098 100644 --- a/ortools/math_opt/cpp/model.h +++ b/ortools/math_opt/cpp/model.h @@ -23,6 +23,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -136,7 +137,7 @@ class Model { // This constructor is used when loading a model, for example from a // ModelProto or an MPS file. Note that in those cases the FromModelProto() // should be used. - explicit Model(std::unique_ptr storage); + explicit Model(absl::Nonnull> storage); Model(const Model&) = delete; Model& operator=(const Model&) = delete; @@ -158,7 +159,7 @@ class Model { // * in an arbitrary order using Variables() and LinearConstraints(). // // Note that the returned model does not have any update tracker. - std::unique_ptr Clone( + absl::Nonnull> Clone( std::optional new_name = std::nullopt) const; inline absl::string_view name() const; @@ -893,13 +894,13 @@ class Model { // // This API is for internal use only and regular users should have no need for // it. - const ModelStorage* storage() const { return storage_.get(); } + ModelStorageCPtr storage() const { return storage_.get(); } // Returns a pointer to the underlying model storage. // // This API is for internal use only and regular users should have no need for // it. - ModelStorage* storage() { return storage_.get(); } + ModelStoragePtr storage() { return storage_.get(); } // Prints the objective, the constraints and the variables of the model over // several lines in a human-readable way. Includes a new line at the end of @@ -911,12 +912,12 @@ class Model { // points to the same model as storage_. // // Use CheckModel() when nullptr is not a valid value. - inline void CheckOptionalModel(const ModelStorage* other_storage) const; + inline void CheckOptionalModel(NullableModelStorageCPtr other_storage) const; // Asserts (with CHECK) that the input pointer is the same as storage_. // // Use CheckOptionalModel() if nullptr is a valid value too. - inline void CheckModel(const ModelStorage* other_storage) const; + inline void CheckModel(ModelStorageCPtr other_storage) const; // Don't use storage_ directly; prefer to use storage() so that const member // functions don't have modifying access to the underlying storage. @@ -924,7 +925,7 @@ class Model { // We use a shared_ptr here so that the UpdateTracker class can have a // weak_ptr on the ModelStorage. This let it have a destructor that don't // crash when called after the destruction of the associated Model. - const std::shared_ptr storage_; + const absl::Nonnull> storage_; }; //////////////////////////////////////////////////////////////////////////////// @@ -973,7 +974,7 @@ int64_t Model::next_variable_id() const { } bool Model::has_variable(const int64_t id) const { - return has_variable(VariableId(id)); + return id < 0 ? false : has_variable(VariableId(id)); } bool Model::has_variable(const VariableId id) const { @@ -1074,7 +1075,7 @@ int64_t Model::next_linear_constraint_id() const { } bool Model::has_linear_constraint(const int64_t id) const { - return has_linear_constraint(LinearConstraintId(id)); + return id < 0 ? false : has_linear_constraint(LinearConstraintId(id)); } bool Model::has_linear_constraint(const LinearConstraintId id) const { @@ -1082,6 +1083,7 @@ bool Model::has_linear_constraint(const LinearConstraintId id) const { } LinearConstraint Model::linear_constraint(const int64_t id) const { + CHECK_GE(id, 0) << "negative linear constraint id: " << id; return linear_constraint(LinearConstraintId(id)); } @@ -1176,7 +1178,7 @@ int64_t Model::next_quadratic_constraint_id() const { } bool Model::has_quadratic_constraint(const int64_t id) const { - return has_quadratic_constraint(QuadraticConstraintId(id)); + return id < 0 ? false : has_quadratic_constraint(QuadraticConstraintId(id)); } bool Model::has_quadratic_constraint(const QuadraticConstraintId id) const { @@ -1184,6 +1186,7 @@ bool Model::has_quadratic_constraint(const QuadraticConstraintId id) const { } QuadraticConstraint Model::quadratic_constraint(const int64_t id) const { + CHECK_GE(id, 0) << "negative quadratic constraint id: " << id; return quadratic_constraint(QuadraticConstraintId(id)); } @@ -1219,7 +1222,9 @@ int64_t Model::next_second_order_cone_constraint_id() const { } bool Model::has_second_order_cone_constraint(const int64_t id) const { - return has_second_order_cone_constraint(SecondOrderConeConstraintId(id)); + return id < 0 ? false + : has_second_order_cone_constraint( + SecondOrderConeConstraintId(id)); } bool Model::has_second_order_cone_constraint( @@ -1265,7 +1270,7 @@ int64_t Model::next_sos1_constraint_id() const { } bool Model::has_sos1_constraint(const int64_t id) const { - return has_sos1_constraint(Sos1ConstraintId(id)); + return id < 0 ? false : has_sos1_constraint(Sos1ConstraintId(id)); } bool Model::has_sos1_constraint(const Sos1ConstraintId id) const { @@ -1306,7 +1311,7 @@ int64_t Model::next_sos2_constraint_id() const { } bool Model::has_sos2_constraint(const int64_t id) const { - return has_sos2_constraint(Sos2ConstraintId(id)); + return id < 0 ? false : has_sos2_constraint(Sos2ConstraintId(id)); } bool Model::has_sos2_constraint(const Sos2ConstraintId id) const { @@ -1347,7 +1352,7 @@ int64_t Model::next_indicator_constraint_id() const { } bool Model::has_indicator_constraint(const int64_t id) const { - return has_indicator_constraint(IndicatorConstraintId(id)); + return id < 0 ? false : has_indicator_constraint(IndicatorConstraintId(id)); } bool Model::has_indicator_constraint(const IndicatorConstraintId id) const { @@ -1555,7 +1560,7 @@ int64_t Model::next_auxiliary_objective_id() const { } bool Model::has_auxiliary_objective(const int64_t id) const { - return has_auxiliary_objective(AuxiliaryObjectiveId(id)); + return id < 0 ? false : has_auxiliary_objective(AuxiliaryObjectiveId(id)); } bool Model::has_auxiliary_objective(const AuxiliaryObjectiveId id) const { @@ -1563,6 +1568,7 @@ bool Model::has_auxiliary_objective(const AuxiliaryObjectiveId id) const { } Objective Model::auxiliary_objective(const int64_t id) const { + CHECK_GE(id, 0) << "negative auxiliary objective id: " << id; return auxiliary_objective(AuxiliaryObjectiveId(id)); } @@ -1618,14 +1624,15 @@ void Model::set_is_maximize(const Objective objective, const bool is_maximize) { storage()->set_is_maximize(objective.typed_id(), is_maximize); } -void Model::CheckOptionalModel(const ModelStorage* const other_storage) const { +void Model::CheckOptionalModel( + const NullableModelStorageCPtr other_storage) const { if (other_storage != nullptr) { CHECK_EQ(other_storage, storage()) << internal::kObjectsFromOtherModelStorage; } } -void Model::CheckModel(const ModelStorage* const other_storage) const { +void Model::CheckModel(const ModelStorageCPtr other_storage) const { CHECK_EQ(other_storage, storage()) << internal::kObjectsFromOtherModelStorage; } diff --git a/ortools/math_opt/cpp/model_solve_parameters.cc b/ortools/math_opt/cpp/model_solve_parameters.cc index 67e9b4cb17..7c7b9b7b9c 100644 --- a/ortools/math_opt/cpp/model_solve_parameters.cc +++ b/ortools/math_opt/cpp/model_solve_parameters.cc @@ -58,7 +58,7 @@ ModelSolveParameters ModelSolveParameters::OnlySomePrimalVariables( } absl::Status ModelSolveParameters::CheckModelStorage( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { for (const SolutionHint& hint : solution_hints) { RETURN_IF_ERROR(hint.CheckModelStorage(expected_storage)) << "invalid hint in solution_hints"; @@ -100,7 +100,7 @@ absl::Status ModelSolveParameters::CheckModelStorage( } absl::Status ModelSolveParameters::SolutionHint::CheckModelStorage( - const ModelStorage* expected_storage) const { + const ModelStorageCPtr expected_storage) const { for (const auto& [v, _] : variable_values) { RETURN_IF_ERROR(internal::CheckModelStorage( /*storage=*/v.storage(), diff --git a/ortools/math_opt/cpp/model_solve_parameters.h b/ortools/math_opt/cpp/model_solve_parameters.h index b09b499c3f..dbad2ba050 100644 --- a/ortools/math_opt/cpp/model_solve_parameters.h +++ b/ortools/math_opt/cpp/model_solve_parameters.h @@ -128,7 +128,7 @@ struct ModelSolveParameters { // Returns a failure if the referenced variables and constraints don't // belong to the input expected_storage (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // Returns the proto equivalent of this object. // @@ -215,7 +215,7 @@ struct ModelSolveParameters { // Returns a failure if the referenced variables and constraints do not belong // to the input expected_storage (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // Returns the proto equivalent of this object. // diff --git a/ortools/math_opt/cpp/objective.cc b/ortools/math_opt/cpp/objective.cc index bd8f3fa31d..6db709e2cd 100644 --- a/ortools/math_opt/cpp/objective.cc +++ b/ortools/math_opt/cpp/objective.cc @@ -31,18 +31,18 @@ LinearExpression Objective::AsLinearExpression() const { << "The objective function contains quadratic terms and cannot be " "represented as a LinearExpression"; LinearExpression objective = offset(); - for (const auto [raw_var_id, coeff] : storage_->linear_objective(id_)) { - objective += coeff * Variable(storage_, raw_var_id); + for (const auto [raw_var_id, coeff] : storage()->linear_objective(id_)) { + objective += coeff * Variable(storage(), raw_var_id); } return objective; } QuadraticExpression Objective::AsQuadraticExpression() const { QuadraticExpression result = offset(); - for (const auto& [v, coef] : storage_->linear_objective(id_)) { + for (const auto& [v, coef] : storage()->linear_objective(id_)) { result += coef * Variable(storage(), v); } - for (const auto& [v1, v2, coef] : storage_->quadratic_objective_terms(id_)) { + for (const auto& [v1, v2, coef] : storage()->quadratic_objective_terms(id_)) { result += QuadraticTerm(Variable(storage(), v1), Variable(storage(), v2), coef); } diff --git a/ortools/math_opt/cpp/objective.h b/ortools/math_opt/cpp/objective.h index 5657369947..389e242b34 100644 --- a/ortools/math_opt/cpp/objective.h +++ b/ortools/math_opt/cpp/objective.h @@ -25,10 +25,10 @@ #include "absl/log/check.h" #include "absl/strings/string_view.h" -#include "ortools/base/strong_int.h" #include "ortools/math_opt/cpp/key_types.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" #include "ortools/math_opt/storage/model_storage_types.h" namespace operations_research::math_opt { @@ -40,15 +40,15 @@ constexpr absl::string_view kDeletedObjectiveDefaultDescription = // ModelStorage. Usually this type is passed by copy. // // This type implements https://abseil.io/docs/cpp/guides/hash. -class Objective { +class Objective final : public ModelStorageItem { public: // The type used for ids. using IdType = AuxiliaryObjectiveId; // Returns an object that refers to the primary objective of the model. - inline static Objective Primary(const ModelStorage* storage); + inline static Objective Primary(ModelStorageCPtr storage); // Returns an object that refers to an auxiliary objective of the model. - inline static Objective Auxiliary(const ModelStorage* storage, + inline static Objective Auxiliary(ModelStorageCPtr storage, AuxiliaryObjectiveId id); // Returns the raw integer ID associated with the objective: nullopt for the @@ -57,8 +57,6 @@ class Objective { // Returns the strong int ID associated with the objective: nullopt for the // primary objective, an AuxiliaryObjectiveId for an auxiliary objective. inline ObjectiveId typed_id() const; - // Returns a const-pointer to the underlying storage object for the model. - inline const ModelStorage* storage() const; // Returns true if the ID corresponds to the primary objective, and false if // it is an auxiliary objective. @@ -113,9 +111,8 @@ class Objective { const Objective& objective); private: - inline Objective(const ModelStorage* storage, ObjectiveId id); + inline Objective(ModelStorageCPtr storage, ObjectiveId id); - const ModelStorage* storage_; ObjectiveId id_; }; @@ -139,68 +136,66 @@ std::optional Objective::id() const { ObjectiveId Objective::typed_id() const { return id_; } -const ModelStorage* Objective::storage() const { return storage_; } - bool Objective::is_primary() const { return id_ == kPrimaryObjectiveId; } int64_t Objective::priority() const { - return storage_->objective_priority(id_); + return storage()->objective_priority(id_); } -bool Objective::maximize() const { return storage_->is_maximize(id_); } +bool Objective::maximize() const { return storage()->is_maximize(id_); } absl::string_view Objective::name() const { - if (is_primary() || storage_->has_auxiliary_objective(*id_)) { - return storage_->objective_name(id_); + if (is_primary() || storage()->has_auxiliary_objective(*id_)) { + return storage()->objective_name(id_); } return kDeletedObjectiveDefaultDescription; } -double Objective::offset() const { return storage_->objective_offset(id_); } +double Objective::offset() const { return storage()->objective_offset(id_); } int64_t Objective::num_quadratic_terms() const { - return storage_->num_quadratic_objective_terms(id_); + return storage()->num_quadratic_objective_terms(id_); } int64_t Objective::num_linear_terms() const { - return storage_->num_linear_objective_terms(id_); + return storage()->num_linear_objective_terms(id_); } double Objective::coefficient(const Variable variable) const { - CHECK_EQ(variable.storage(), storage_) + CHECK_EQ(variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->linear_objective_coefficient(id_, variable.typed_id()); + return storage()->linear_objective_coefficient(id_, variable.typed_id()); } double Objective::coefficient(const Variable first_variable, const Variable second_variable) const { - CHECK_EQ(first_variable.storage(), storage_) + CHECK_EQ(first_variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - CHECK_EQ(second_variable.storage(), storage_) + CHECK_EQ(second_variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->quadratic_objective_coefficient( + return storage()->quadratic_objective_coefficient( id_, first_variable.typed_id(), second_variable.typed_id()); } bool Objective::is_coefficient_nonzero(const Variable variable) const { - CHECK_EQ(variable.storage(), storage_) + CHECK_EQ(variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->is_linear_objective_coefficient_nonzero(id_, - variable.typed_id()); + return storage()->is_linear_objective_coefficient_nonzero( + id_, variable.typed_id()); } bool Objective::is_coefficient_nonzero(const Variable first_variable, const Variable second_variable) const { - CHECK_EQ(first_variable.storage(), storage_) + CHECK_EQ(first_variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - CHECK_EQ(second_variable.storage(), storage_) + CHECK_EQ(second_variable.storage(), storage()) << internal::kObjectsFromOtherModelStorage; - return storage_->is_quadratic_objective_coefficient_nonzero( + return storage()->is_quadratic_objective_coefficient_nonzero( id_, first_variable.typed_id(), second_variable.typed_id()); } bool operator==(const Objective& lhs, const Objective& rhs) { - return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_; + return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage(); } bool operator!=(const Objective& lhs, const Objective& rhs) { @@ -209,17 +204,17 @@ bool operator!=(const Objective& lhs, const Objective& rhs) { template H AbslHashValue(H h, const Objective& objective) { - return H::combine(std::move(h), objective.id_, objective.storage_); + return H::combine(std::move(h), objective.id_, objective.storage()); } -Objective::Objective(const ModelStorage* const storage, const ObjectiveId id) - : storage_(storage), id_(id) {} +Objective::Objective(const ModelStorageCPtr storage, const ObjectiveId id) + : ModelStorageItem(storage), id_(id) {} -Objective Objective::Primary(const ModelStorage* const storage) { +Objective Objective::Primary(const ModelStorageCPtr storage) { return Objective(storage, kPrimaryObjectiveId); } -Objective Objective::Auxiliary(const ModelStorage* const storage, +Objective Objective::Auxiliary(const ModelStorageCPtr storage, const AuxiliaryObjectiveId id) { return Objective(storage, id); } diff --git a/ortools/math_opt/cpp/parameters.cc b/ortools/math_opt/cpp/parameters.cc index 87ccbb9a83..d0826f5e0c 100644 --- a/ortools/math_opt/cpp/parameters.cc +++ b/ortools/math_opt/cpp/parameters.cc @@ -13,7 +13,6 @@ #include "ortools/math_opt/cpp/parameters.h" -#include #include #include #include @@ -86,7 +85,7 @@ std::optional Enum::ToOptString( case SolverType::kSantorini: return "santorini"; case SolverType::kXpress: - return "xpress"; + return "xpress"; } return std::nullopt; } @@ -96,7 +95,7 @@ absl::Span Enum::AllValues() { SolverType::kGscip, SolverType::kGurobi, SolverType::kGlop, SolverType::kCpSat, SolverType::kPdlp, SolverType::kGlpk, SolverType::kEcos, SolverType::kScs, SolverType::kHighs, - SolverType::kSantorini, + SolverType::kSantorini, SolverType::kXpress, }; return absl::MakeConstSpan(kSolverTypeValues); } diff --git a/ortools/math_opt/cpp/parameters.h b/ortools/math_opt/cpp/parameters.h index cb44d3d443..ff94c7c1e9 100644 --- a/ortools/math_opt/cpp/parameters.h +++ b/ortools/math_opt/cpp/parameters.h @@ -24,7 +24,6 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" -#include "absl/types/span.h" #include "ortools/base/linked_hash_map.h" #include "ortools/glop/parameters.pb.h" // IWYU pragma: export #include "ortools/gscip/gscip.pb.h" // IWYU pragma: export @@ -114,7 +113,7 @@ enum class SolverType { // // Supports LP, MIP, and nonconvex integer quadratic problems. // A fast option, but has special licensing. - kXpress = SOLVER_TYPE_XPRESS + kXpress = SOLVER_TYPE_XPRESS, }; MATH_OPT_DEFINE_ENUM(SolverType, SOLVER_TYPE_UNSPECIFIED); diff --git a/ortools/math_opt/cpp/solution.cc b/ortools/math_opt/cpp/solution.cc index cffc453efe..f43468eeb5 100644 --- a/ortools/math_opt/cpp/solution.cc +++ b/ortools/math_opt/cpp/solution.cc @@ -18,10 +18,10 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "ortools/base/logging.h" #include "ortools/base/status_builder.h" #include "ortools/base/status_macros.h" #include "ortools/math_opt/cpp/sparse_containers.h" @@ -57,7 +57,7 @@ absl::Span Enum::AllValues() { } absl::StatusOr PrimalSolution::FromProto( - const ModelStorage* model, + const ModelStorageCPtr model, const PrimalSolutionProto& primal_solution_proto) { PrimalSolution primal_solution; OR_ASSIGN_OR_RETURN3( @@ -104,7 +104,7 @@ double PrimalSolution::get_objective_value(const Objective objective) const { } absl::StatusOr PrimalRay::FromProto( - const ModelStorage* model, const PrimalRayProto& primal_ray_proto) { + const ModelStorageCPtr model, const PrimalRayProto& primal_ray_proto) { PrimalRay result; OR_ASSIGN_OR_RETURN3( result.variable_values, @@ -120,7 +120,8 @@ PrimalRayProto PrimalRay::Proto() const { } absl::StatusOr DualSolution::FromProto( - const ModelStorage* model, const DualSolutionProto& dual_solution_proto) { + const ModelStorageCPtr model, + const DualSolutionProto& dual_solution_proto) { DualSolution dual_solution; OR_ASSIGN_OR_RETURN3( dual_solution.dual_values, @@ -159,7 +160,7 @@ DualSolutionProto DualSolution::Proto() const { return result; } -absl::StatusOr DualRay::FromProto(const ModelStorage* model, +absl::StatusOr DualRay::FromProto(const ModelStorageCPtr model, const DualRayProto& dual_ray_proto) { DualRay result; OR_ASSIGN_OR_RETURN3( @@ -180,7 +181,7 @@ DualRayProto DualRay::Proto() const { return result; } -absl::StatusOr Basis::FromProto(const ModelStorage* model, +absl::StatusOr Basis::FromProto(const ModelStorageCPtr model, const BasisProto& basis_proto) { Basis basis; OR_ASSIGN_OR_RETURN3( @@ -197,7 +198,7 @@ absl::StatusOr Basis::FromProto(const ModelStorage* model, } absl::Status Basis::CheckModelStorage( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { for (const auto& [v, _] : variable_status) { RETURN_IF_ERROR( internal::CheckModelStorage(/*storage=*/v.storage(), @@ -223,7 +224,7 @@ BasisProto Basis::Proto() const { } absl::StatusOr Solution::FromProto( - const ModelStorage* model, const SolutionProto& solution_proto) { + const ModelStorageCPtr model, const SolutionProto& solution_proto) { Solution solution; if (solution_proto.has_primal_solution()) { OR_ASSIGN_OR_RETURN3( diff --git a/ortools/math_opt/cpp/solution.h b/ortools/math_opt/cpp/solution.h index 4177e3d676..a4218d2022 100644 --- a/ortools/math_opt/cpp/solution.h +++ b/ortools/math_opt/cpp/solution.h @@ -67,8 +67,7 @@ struct PrimalSolution { // * VariableValuesFromProto(primal_solution_proto.variable_values) fails. // * the feasibility_status is not specified. static absl::StatusOr FromProto( - const ModelStorage* model, - const PrimalSolutionProto& primal_solution_proto); + ModelStorageCPtr model, const PrimalSolutionProto& primal_solution_proto); // Returns the proto equivalent of this. PrimalSolutionProto Proto() const; @@ -112,7 +111,7 @@ struct PrimalRay { // Returns an error when // VariableValuesFromProto(primal_ray_proto.variable_values) fails. static absl::StatusOr FromProto( - const ModelStorage* model, const PrimalRayProto& primal_ray_proto); + ModelStorageCPtr model, const PrimalRayProto& primal_ray_proto); // Returns the proto equivalent of this. PrimalRayProto Proto() const; @@ -139,7 +138,7 @@ struct DualSolution { // * LinearConstraintValuesFromProto(dual_solution_proto.dual_values) fails. // * dual_solution_proto.feasibility_status is not specified. static absl::StatusOr FromProto( - const ModelStorage* model, const DualSolutionProto& dual_solution_proto); + ModelStorageCPtr model, const DualSolutionProto& dual_solution_proto); // Returns the proto equivalent of this. DualSolutionProto Proto() const; @@ -179,7 +178,7 @@ struct DualRay { // Returns an error when either of: // * VariableValuesFromProto(dual_ray_proto.reduced_costs) fails. // * LinearConstraintValuesFromProto(dual_ray_proto.dual_values) fails. - static absl::StatusOr FromProto(const ModelStorage* model, + static absl::StatusOr FromProto(ModelStorageCPtr model, const DualRayProto& dual_ray_proto); // Returns the proto equivalent of this. @@ -218,12 +217,12 @@ struct Basis { // Returns an error if: // * VariableBasisFromProto(basis_proto.variable_status) fails. // * LinearConstraintBasisFromProto(basis_proto.constraint_status) fails. - static absl::StatusOr FromProto(const ModelStorage* model, + static absl::StatusOr FromProto(ModelStorageCPtr model, const BasisProto& basis_proto); // Returns a failure if the referenced variables don't belong to the input // expected_storage (which must not be nullptr). - absl::Status CheckModelStorage(const ModelStorage* expected_storage) const; + absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const; // Returns the proto equivalent of this object. // @@ -262,7 +261,7 @@ struct Solution { // Returns an error if FromProto() fails on any field that is not std::nullopt // (see the static FromProto() functions for each field type for details). static absl::StatusOr FromProto( - const ModelStorage* model, const SolutionProto& solution_proto); + ModelStorageCPtr model, const SolutionProto& solution_proto); // Returns the proto equivalent of this. SolutionProto Proto() const; diff --git a/ortools/math_opt/cpp/solve_arguments.cc b/ortools/math_opt/cpp/solve_arguments.cc index 2c25f70ecb..c2df0d6b7a 100644 --- a/ortools/math_opt/cpp/solve_arguments.cc +++ b/ortools/math_opt/cpp/solve_arguments.cc @@ -25,7 +25,7 @@ namespace operations_research::math_opt { absl::Status SolveArguments::CheckModelStorageAndCallback( - const ModelStorage* const expected_storage) const { + const ModelStorageCPtr expected_storage) const { RETURN_IF_ERROR(model_parameters.CheckModelStorage(expected_storage)) << "invalid model_parameters"; RETURN_IF_ERROR(callback_registration.CheckModelStorage(expected_storage)) diff --git a/ortools/math_opt/cpp/solve_arguments.h b/ortools/math_opt/cpp/solve_arguments.h index 4738532fc1..2bc252705e 100644 --- a/ortools/math_opt/cpp/solve_arguments.h +++ b/ortools/math_opt/cpp/solve_arguments.h @@ -88,7 +88,7 @@ struct SolveArguments { // to the input expected_storage (which must not be nullptr). Also returns a // failure if callback events are registered but no callback is provided. absl::Status CheckModelStorageAndCallback( - const ModelStorage* expected_storage) const; + ModelStorageCPtr expected_storage) const; }; } // namespace operations_research::math_opt diff --git a/ortools/math_opt/cpp/solve_impl.cc b/ortools/math_opt/cpp/solve_impl.cc index a886b2870b..b77443820f 100644 --- a/ortools/math_opt/cpp/solve_impl.cc +++ b/ortools/math_opt/cpp/solve_impl.cc @@ -40,9 +40,10 @@ namespace operations_research::math_opt::internal { namespace { -absl::StatusOr CallSolve( - BaseSolver& solver, const ModelStorage* const expected_storage, - const SolveArguments& arguments, SolveInterrupter& local_canceller) { +absl::StatusOr CallSolve(BaseSolver& solver, + const ModelStorageCPtr expected_storage, + const SolveArguments& arguments, + SolveInterrupter& local_canceller) { RETURN_IF_ERROR(arguments.CheckModelStorageAndCallback(expected_storage)); BaseSolver::Callback cb = nullptr; @@ -104,7 +105,7 @@ absl::StatusOr CallSolve( } absl::StatusOr CallComputeInfeasibleSubsystem( - BaseSolver& solver, const ModelStorage* const expected_storage, + BaseSolver& solver, const ModelStorageCPtr expected_storage, const ComputeInfeasibleSubsystemArguments& arguments, SolveInterrupter& local_canceller) { ASSIGN_OR_RETURN( @@ -180,7 +181,7 @@ IncrementalSolverImpl::IncrementalSolverImpl( BaseSolverFactory solver_factory, SolverType solver_type, const bool remove_names, std::shared_ptr local_canceller, std::unique_ptr user_canceller_cb, - const ModelStorage* const expected_storage, + const ModelStorageCPtr expected_storage, std::unique_ptr update_tracker, std::unique_ptr solver) : solver_factory_(std::move(solver_factory)), diff --git a/ortools/math_opt/cpp/solve_impl.h b/ortools/math_opt/cpp/solve_impl.h index 7f52a0d9e5..803673063e 100644 --- a/ortools/math_opt/cpp/solve_impl.h +++ b/ortools/math_opt/cpp/solve_impl.h @@ -102,7 +102,7 @@ class IncrementalSolverImpl : public IncrementalSolver { BaseSolverFactory solver_factory, SolverType solver_type, bool remove_names, std::shared_ptr local_canceller, std::unique_ptr user_canceller_cb, - const ModelStorage* expected_storage, + ModelStorageCPtr expected_storage, std::unique_ptr update_tracker, std::unique_ptr solver); @@ -114,7 +114,7 @@ class IncrementalSolverImpl : public IncrementalSolver { // can be destroyed after local_canceller_ without risk. std::shared_ptr local_canceller_; std::unique_ptr user_canceller_cb_; - const ModelStorage* const expected_storage_; + const ModelStorageCPtr expected_storage_; const std::unique_ptr update_tracker_; std::unique_ptr solver_; }; diff --git a/ortools/math_opt/cpp/solve_result.cc b/ortools/math_opt/cpp/solve_result.cc index 9b81365e9c..ae0cab0da0 100644 --- a/ortools/math_opt/cpp/solve_result.cc +++ b/ortools/math_opt/cpp/solve_result.cc @@ -536,7 +536,7 @@ TerminationProto UpgradedTerminationProtoForStatsMigration( } // namespace absl::StatusOr SolveResult::FromProto( - const ModelStorage* model, const SolveResultProto& solve_result_proto) { + const ModelStorageCPtr model, const SolveResultProto& solve_result_proto) { OR_ASSIGN_OR_RETURN3( auto termination, Termination::FromProto( diff --git a/ortools/math_opt/cpp/solve_result.h b/ortools/math_opt/cpp/solve_result.h index 33e54d538a..42bee7fc75 100644 --- a/ortools/math_opt/cpp/solve_result.h +++ b/ortools/math_opt/cpp/solve_result.h @@ -518,7 +518,7 @@ struct SolveResult { // validation, or not rely on the strong guarantees of ValidateResult() // and just treat SolveResult as a simple struct. static absl::StatusOr FromProto( - const ModelStorage* model, const SolveResultProto& solve_result_proto); + ModelStorageCPtr model, const SolveResultProto& solve_result_proto); // Returns the proto equivalent of this. // diff --git a/ortools/math_opt/cpp/sparse_containers.cc b/ortools/math_opt/cpp/sparse_containers.cc index 4ca4bef7ba..63a0ea3f33 100644 --- a/ortools/math_opt/cpp/sparse_containers.cc +++ b/ortools/math_opt/cpp/sparse_containers.cc @@ -49,8 +49,7 @@ absl::Status CheckSparseVectorProto(const SparseVectorProtoType& vec) { template absl::StatusOr> BasisVectorFromProto( - const ModelStorage* const model, - const SparseBasisStatusVector& basis_proto) { + const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) { using IdType = typename Key::IdType; absl::flat_hash_map map; map.reserve(basis_proto.ids_size()); @@ -104,7 +103,7 @@ SparseBasisStatusVector BasisMapToProto( return result; } -absl::Status VariableIdsExist(const ModelStorage* const model, +absl::Status VariableIdsExist(const ModelStorageCPtr model, const absl::Span ids) { for (const int64_t id : ids) { if (!model->has_variable(VariableId(id))) { @@ -115,7 +114,7 @@ absl::Status VariableIdsExist(const ModelStorage* const model, return absl::OkStatus(); } -absl::Status LinearConstraintIdsExist(const ModelStorage* const model, +absl::Status LinearConstraintIdsExist(const ModelStorageCPtr model, const absl::Span ids) { for (const int64_t id : ids) { if (!model->has_linear_constraint(LinearConstraintId(id))) { @@ -126,7 +125,7 @@ absl::Status LinearConstraintIdsExist(const ModelStorage* const model, return absl::OkStatus(); } -absl::Status QuadraticConstraintIdsExist(const ModelStorage* const model, +absl::Status QuadraticConstraintIdsExist(const ModelStorageCPtr model, const absl::Span ids) { for (const int64_t id : ids) { if (!model->has_constraint(QuadraticConstraintId(id))) { @@ -140,15 +139,14 @@ absl::Status QuadraticConstraintIdsExist(const ModelStorage* const model, } // namespace absl::StatusOr> VariableValuesFromProto( - const ModelStorage* const model, - const SparseDoubleVectorProto& vars_proto) { + const ModelStorageCPtr model, const SparseDoubleVectorProto& vars_proto) { RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto)); RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids())); return MakeView(vars_proto).as_map(model); } absl::StatusOr> VariableValuesFromProto( - const ModelStorage* model, const SparseInt32VectorProto& vars_proto) { + const ModelStorageCPtr model, const SparseInt32VectorProto& vars_proto) { RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto)); RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids())); return MakeView(vars_proto).as_map(model); @@ -161,7 +159,7 @@ SparseDoubleVectorProto VariableValuesToProto( absl::StatusOr> AuxiliaryObjectiveValuesFromProto( - const ModelStorage* const model, + const ModelStorageCPtr model, const google::protobuf::Map& aux_obj_proto) { absl::flat_hash_map result; for (const auto [raw_id, value] : aux_obj_proto) { @@ -187,7 +185,7 @@ google::protobuf::Map AuxiliaryObjectiveValuesToProto( } absl::StatusOr> LinearConstraintValuesFromProto( - const ModelStorage* const model, + const ModelStorageCPtr model, const SparseDoubleVectorProto& lin_cons_proto) { RETURN_IF_ERROR(CheckSparseVectorProto(lin_cons_proto)); RETURN_IF_ERROR(LinearConstraintIdsExist(model, lin_cons_proto.ids())); @@ -201,7 +199,7 @@ SparseDoubleVectorProto LinearConstraintValuesToProto( absl::StatusOr> QuadraticConstraintValuesFromProto( - const ModelStorage* const model, + const ModelStorageCPtr model, const SparseDoubleVectorProto& quad_cons_proto) { RETURN_IF_ERROR(CheckSparseVectorProto(quad_cons_proto)); RETURN_IF_ERROR(QuadraticConstraintIdsExist(model, quad_cons_proto.ids())); @@ -215,8 +213,7 @@ SparseDoubleVectorProto QuadraticConstraintValuesToProto( } absl::StatusOr> VariableBasisFromProto( - const ModelStorage* const model, - const SparseBasisStatusVector& basis_proto) { + const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) { RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto)); RETURN_IF_ERROR(VariableIdsExist(model, basis_proto.ids())); return BasisVectorFromProto(model, basis_proto); @@ -228,8 +225,7 @@ SparseBasisStatusVector VariableBasisToProto( } absl::StatusOr> LinearConstraintBasisFromProto( - const ModelStorage* const model, - const SparseBasisStatusVector& basis_proto) { + const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) { RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto)); RETURN_IF_ERROR(LinearConstraintIdsExist(model, basis_proto.ids())); return BasisVectorFromProto(model, basis_proto); diff --git a/ortools/math_opt/cpp/sparse_containers.h b/ortools/math_opt/cpp/sparse_containers.h index f6411afd28..9627fc2431 100644 --- a/ortools/math_opt/cpp/sparse_containers.h +++ b/ortools/math_opt/cpp/sparse_containers.h @@ -42,7 +42,7 @@ namespace operations_research::math_opt { // // Note that the values of vars_proto.values are not checked (it may have NaNs). absl::StatusOr> VariableValuesFromProto( - const ModelStorage* model, const SparseDoubleVectorProto& vars_proto); + ModelStorageCPtr model, const SparseDoubleVectorProto& vars_proto); // Returns the VariableMap equivalent to `vars_proto`. // @@ -52,7 +52,7 @@ absl::StatusOr> VariableValuesFromProto( // * vars_proto.ids has elements that are variables in `model` (this implies // that each id is in [0, max(int64_t))). absl::StatusOr> VariableValuesFromProto( - const ModelStorage* model, const SparseInt32VectorProto& vars_proto); + ModelStorageCPtr model, const SparseInt32VectorProto& vars_proto); // Returns the proto equivalent of variable_values. SparseDoubleVectorProto VariableValuesToProto( @@ -67,7 +67,7 @@ SparseDoubleVectorProto VariableValuesToProto( // Note that the values of `aux_obj_proto` are not checked (it may have NaNs). absl::StatusOr> AuxiliaryObjectiveValuesFromProto( - const ModelStorage* model, + ModelStorageCPtr model, const google::protobuf::Map& aux_obj_proto); // Returns the proto equivalent of auxiliary_obj_values. @@ -88,7 +88,7 @@ google::protobuf::Map AuxiliaryObjectiveValuesToProto( // Note that the values of lin_cons_proto.values are not checked (it may have // NaNs). absl::StatusOr> LinearConstraintValuesFromProto( - const ModelStorage* model, const SparseDoubleVectorProto& lin_cons_proto); + ModelStorageCPtr model, const SparseDoubleVectorProto& lin_cons_proto); // Returns the proto equivalent of linear_constraint_values. SparseDoubleVectorProto LinearConstraintValuesToProto( @@ -107,7 +107,7 @@ SparseDoubleVectorProto LinearConstraintValuesToProto( // NaNs). absl::StatusOr> QuadraticConstraintValuesFromProto( - const ModelStorage* model, const SparseDoubleVectorProto& quad_cons_proto); + ModelStorageCPtr model, const SparseDoubleVectorProto& quad_cons_proto); // Returns the proto equivalent of quadratic_constraint_values. SparseDoubleVectorProto QuadraticConstraintValuesToProto( @@ -123,7 +123,7 @@ SparseDoubleVectorProto QuadraticConstraintValuesToProto( // that each id is in [0, max(int64_t))). // * basis_proto.values does not contain UNSPECIFIED and has valid enum values. absl::StatusOr> VariableBasisFromProto( - const ModelStorage* model, const SparseBasisStatusVector& basis_proto); + ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto); // Returns the proto equivalent of basis_values. SparseBasisStatusVector VariableBasisToProto( @@ -138,7 +138,7 @@ SparseBasisStatusVector VariableBasisToProto( // implies that each id is in [0, max(int64_t))). // * basis_proto.values does not contain UNSPECIFIED and has valid enum values. absl::StatusOr> LinearConstraintBasisFromProto( - const ModelStorage* model, const SparseBasisStatusVector& basis_proto); + ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto); // Returns the proto equivalent of basis_values. SparseBasisStatusVector LinearConstraintBasisToProto( diff --git a/ortools/math_opt/cpp/variable_and_expressions.cc b/ortools/math_opt/cpp/variable_and_expressions.cc index 0f50067b00..d3fbcd3d62 100644 --- a/ortools/math_opt/cpp/variable_and_expressions.cc +++ b/ortools/math_opt/cpp/variable_and_expressions.cc @@ -24,6 +24,9 @@ #include "ortools/base/map_util.h" #include "ortools/base/strong_int.h" #include "ortools/math_opt/cpp/formatters.h" +#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS +#include "ortools/math_opt/storage/model_storage_item.h" +#endif // MATH_OPT_USE_EXPRESSION_COUNTERS #include "ortools/util/fp_roundtrip_conv.h" namespace operations_research { @@ -35,7 +38,9 @@ constexpr double kInf = std::numeric_limits::infinity(); LinearExpression::LinearExpression() { ++num_calls_default_constructor_; } LinearExpression::LinearExpression(const LinearExpression& other) - : storage_(other.storage_), terms_(other.terms_), offset_(other.offset_) { + : ModelStorageItemContainer(other.storage()), + terms_(other.terms_), + offset_(other.offset_) { ++num_calls_copy_constructor_; } @@ -203,7 +208,7 @@ std::ostream& operator<<(std::ostream& ostr, QuadraticExpression::QuadraticExpression() { ++num_calls_default_constructor_; } QuadraticExpression::QuadraticExpression(const QuadraticExpression& other) - : storage_(other.storage_), + : ModelStorageItemContainer(other), quadratic_terms_(other.quadratic_terms_), linear_terms_(other.linear_terms_), offset_(other.offset_) { diff --git a/ortools/math_opt/cpp/variable_and_expressions.h b/ortools/math_opt/cpp/variable_and_expressions.h index 6cda53489c..5bd36030e6 100644 --- a/ortools/math_opt/cpp/variable_and_expressions.h +++ b/ortools/math_opt/cpp/variable_and_expressions.h @@ -103,11 +103,11 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/string_view.h" -#include "ortools/base/logging.h" -#include "ortools/base/strong_int.h" #include "ortools/math_opt/cpp/key_types.h" // IWYU pragma: export #include "ortools/math_opt/storage/model_storage.h" +#include "ortools/math_opt/storage/model_storage_item.h" #include "ortools/math_opt/storage/model_storage_types.h" namespace operations_research { @@ -118,40 +118,20 @@ class LinearExpression; // A value type that references a variable from ModelStorage. Usually this type // is passed by copy. -// -// This type implements https://abseil.io/docs/cpp/guides/hash (see -// VariablesEquality for details about how operator== works). -class Variable { +class Variable final : public ModelStorageElement< + ElementType::kVariable, Variable, + // This type has a special equality operator + // (see `VariablesEquality` below). + ModelStorageElementEquality::kWithoutEquality> { public: - // The typed integer used for ids. - using IdType = VariableId; - - // Usually users will obtain variables using Model::AddVariable(). There - // should be little for users to build this object from an ModelStorage. - inline Variable(const ModelStorage* storage, VariableId id); - - // Each call to AddVariable will produce Variables id() increasing by one, - // starting at zero. Deleted ids are NOT reused. Thus, if no variables are - // deleted, the ids in the model will be consecutive. - inline int64_t id() const; - - inline VariableId typed_id() const; - inline const ModelStorage* storage() const; + using ModelStorageElement::ModelStorageElement; inline double lower_bound() const; inline double upper_bound() const; inline bool is_integer() const; inline absl::string_view name() const; - template - friend H AbslHashValue(H h, const Variable& variable); - friend std::ostream& operator<<(std::ostream& ostr, const Variable& variable); - inline LinearExpression operator-() const; - - private: - const ModelStorage* storage_; - VariableId id_; }; namespace internal { @@ -186,8 +166,6 @@ inline bool operator!=(const Variable& lhs, const Variable& rhs); template using VariableMap = absl::flat_hash_map; -inline std::ostream& operator<<(std::ostream& ostr, const Variable& variable); - // A term in an sum of variables multiplied by coefficients. struct LinearTerm { // Usually this constructor is never called explicitly by users. Instead it @@ -226,7 +204,7 @@ class QuadraticExpression; // TODO(b/169415098): add a function to remove zero terms. // TODO(b/169415834): study if exact zeros should be automatically removed. // TODO(b/169415103): add tests that some expressions don't compile. -class LinearExpression { +class LinearExpression final : public ModelStorageItemContainer { public: // For unit testing purpose, we define optional counters. We have to // explicitly define the default constructor, copy constructor and assignment @@ -238,9 +216,6 @@ class LinearExpression { LinearExpression(); LinearExpression(const LinearExpression& other); #endif // MATH_OPT_USE_EXPRESSION_COUNTERS - // We have to define a custom move constructor as we need to reset storage_ to - // nullptr. - inline LinearExpression(LinearExpression&& other) noexcept; // Usually users should use the overloads of operators to build linear // expressions. For example, assuming `x` and `y` are Variable, then `x + 2*y // + 5` will build a LinearExpression automatically. @@ -250,8 +225,9 @@ class LinearExpression { inline LinearExpression(Variable variable); // NOLINT inline LinearExpression(const LinearTerm& term); // NOLINT LinearExpression& operator=(const LinearExpression& other) = default; - // We have to define a custom move assignment operator as we need to reset - // storage_ to nullptr. + // A moved-from `LinearExpression` is the zero expression: it's not associated + // to a storage, has no terms and its offset is zero. + inline LinearExpression(LinearExpression&& other) noexcept; inline LinearExpression& operator=(LinearExpression&& other) noexcept; inline LinearExpression& operator+=(const LinearExpression& other); @@ -367,8 +343,6 @@ class LinearExpression { double EvaluateWithDefaultZero( const VariableMap& variable_values) const; - inline const ModelStorage* storage() const; - #ifdef MATH_OPT_USE_EXPRESSION_COUNTERS static thread_local int num_calls_default_constructor_; static thread_local int num_calls_copy_constructor_; @@ -384,14 +358,9 @@ class LinearExpression { const LinearExpression& expression); friend QuadraticExpression; - // Sets the storage_ to the input value if nullptr, else CHECKs that it is - // equal. Also CHECKs that the input value is not nullptr. - inline void SetOrCheckStorage(const ModelStorage* storage); - // Invariants: - // * nullptr, if terms_ is empty - // * equal to Variable::storage() of each key of terms_, else - const ModelStorage* storage_ = nullptr; + // * storage() == v.storage()for each v in terms_; + // * storage() == nullptr, if terms_ is empty. VariableMap terms_; double offset_ = 0.0; }; @@ -647,14 +616,14 @@ using QuadraticProductId = std::pair; // silently correct this if not satisfied by the inputs. // // This type can be used as a key in ABSL hash containers. -class QuadraticTermKey { +class QuadraticTermKey final : public ModelStorageItem { public: // NOTE: this definition is for use by IdMap; clients should not rely upon it. using IdType = QuadraticProductId; // NOTE: This constructor will silently re-order the passed id so that, upon // exiting the constructor, variable_ids_.first <= variable_ids_.second. - inline QuadraticTermKey(const ModelStorage* storage, QuadraticProductId id); + inline QuadraticTermKey(ModelStorageCPtr storage, QuadraticProductId id); // NOTE: This constructor will CHECK fail if the variable models do not agree, // i.e. first_variable.storage() != second_variable.storage(). It will also // silently re-order the passed id so that, upon exiting the constructor, @@ -662,19 +631,17 @@ class QuadraticTermKey { inline QuadraticTermKey(Variable first_variable, Variable second_variable); inline QuadraticProductId typed_id() const; - inline const ModelStorage* storage() const; // Returns the Variable with the smallest id. - Variable first() const { return Variable(storage_, variable_ids_.first); } + Variable first() const { return Variable(storage(), variable_ids_.first); } // Returns the Variable the largest id. - Variable second() const { return Variable(storage_, variable_ids_.second); } + Variable second() const { return Variable(storage(), variable_ids_.second); } template friend H AbslHashValue(H h, const QuadraticTermKey& key); private: - const ModelStorage* storage_; QuadraticProductId variable_ids_; }; @@ -744,7 +711,7 @@ using QuadraticTermMap = absl::flat_hash_map; // is, it is forbidden that both are non-null and not equal. Use // CheckModelsAgree() and the initializer_list constructor to enforce this // invariant in any class or friend method. -class QuadraticExpression { +class QuadraticExpression final : public ModelStorageItemContainer { public: // For unit testing purpose, we define optional counters. We have to // explicitly define the default constructor, copy constructor and assignment @@ -756,9 +723,6 @@ class QuadraticExpression { QuadraticExpression(); QuadraticExpression(const QuadraticExpression& other); #endif // MATH_OPT_USE_EXPRESSION_COUNTERS - // We have to define a custom move constructor as we need to reset storage_ to - // nullptr. - inline QuadraticExpression(QuadraticExpression&& other) noexcept; // Users should prefer the default constructor and operator overloads to build // expressions. inline QuadraticExpression( @@ -770,8 +734,9 @@ class QuadraticExpression { inline QuadraticExpression(LinearExpression expr); // NOLINT inline QuadraticExpression(const QuadraticTerm& term); // NOLINT QuadraticExpression& operator=(const QuadraticExpression& other) = default; - // We have to define a custom move assignment operator as we need to reset - // storage_ to nullptr. + // A moved-from `LinearExpression` is the zero expression: it's not associated + // to a storage, has no terms and its offset is zero. + inline QuadraticExpression(QuadraticExpression&& other) noexcept; inline QuadraticExpression& operator=(QuadraticExpression&& other) noexcept; inline double offset() const; @@ -933,8 +898,6 @@ class QuadraticExpression { double EvaluateWithDefaultZero( const VariableMap& variable_values) const; - inline const ModelStorage* storage() const; - #ifdef MATH_OPT_USE_EXPRESSION_COUNTERS static thread_local int num_calls_default_constructor_; static thread_local int num_calls_copy_constructor_; @@ -950,15 +913,10 @@ class QuadraticExpression { friend std::ostream& operator<<(std::ostream& ostr, const QuadraticExpression& expr); - // Sets the storage_ to the input value if nullptr, else CHECKs that it is - // equal. Also CHECKs that the input value is not nullptr. - inline void SetOrCheckStorage(const ModelStorage* storage); - // Invariants: - // * nullptr, if both quadratic_terms_ and linear_terms_ are empty - // * equal to Variable::storage() of each key of linear_terms_ and - // QuadraticTermKey::storage() of each key of quadratic_terms_, else - const ModelStorage* storage_ = nullptr; + // * storage() == v.storage() for each v in linear_terms_; + // * storage() == v.storage() for each v in quadratic_terms_; + // * storage() == nullptr, if both terms_ and quadratic_terms_ are empty. QuadraticTermMap quadratic_terms_; VariableMap linear_terms_; double offset_ = 0.0; @@ -1268,50 +1226,25 @@ inline BoundedQuadraticExpression operator==(double lhs, // Variable //////////////////////////////////////////////////////////////////////////////// -Variable::Variable(const ModelStorage* const storage, const VariableId id) - : storage_(storage), id_(id) { - DCHECK(storage != nullptr); -} - -int64_t Variable::id() const { return id_.value(); } - -VariableId Variable::typed_id() const { return id_; } - -const ModelStorage* Variable::storage() const { return storage_; } - double Variable::lower_bound() const { - return storage_->variable_lower_bound(id_); + return storage()->variable_lower_bound(typed_id()); } double Variable::upper_bound() const { - return storage_->variable_upper_bound(id_); + return storage()->variable_upper_bound(typed_id()); } -bool Variable::is_integer() const { return storage_->is_variable_integer(id_); } +bool Variable::is_integer() const { + return storage()->is_variable_integer(typed_id()); +} absl::string_view Variable::name() const { - if (storage()->has_variable(id_)) { - return storage_->variable_name(id_); + if (storage()->has_variable(typed_id())) { + return storage()->variable_name(typed_id()); } return "[variable deleted from model]"; } -template -H AbslHashValue(H h, const Variable& variable) { - return H::combine(std::move(h), variable.id_.value(), variable.storage_); -} - -std::ostream& operator<<(std::ostream& ostr, const Variable& variable) { - // TODO(b/170992529): handle quoting of invalid characters in the name. - const absl::string_view name = variable.name(); - if (name.empty()) { - ostr << "__var#" << variable.id() << "__"; - } else { - ostr << name; - } - return ostr; -} - LinearExpression Variable::operator-() const { return LinearExpression({LinearTerm(*this, -1.0)}, 0.0); } @@ -1368,17 +1301,9 @@ LinearTerm operator/(Variable variable, const double coefficient) { // LinearExpression //////////////////////////////////////////////////////////////////////////////// -void LinearExpression::SetOrCheckStorage(const ModelStorage* const storage) { - CHECK(storage != nullptr) << internal::kKeyHasNullModelStorage; - if (storage_ == nullptr) { - storage_ = storage; - return; - } - CHECK_EQ(storage, storage_) << internal::kObjectsFromOtherModelStorage; -} - LinearExpression::LinearExpression(LinearExpression&& other) noexcept - : storage_(std::exchange(other.storage_, nullptr)), + : ModelStorageItemContainer( + static_cast(other)), terms_(std::move(other.terms_)), offset_(std::exchange(other.offset_, 0.0)) { other.terms_.clear(); @@ -1389,7 +1314,8 @@ LinearExpression::LinearExpression(LinearExpression&& other) noexcept LinearExpression& LinearExpression::operator=( LinearExpression&& other) noexcept { - storage_ = std::exchange(other.storage_, nullptr); + ModelStorageItemContainer::operator=( + static_cast(other)); terms_ = std::move(other.terms_); other.terms_.clear(); offset_ = std::exchange(other.offset_, 0.0); @@ -1403,7 +1329,7 @@ LinearExpression::LinearExpression(std::initializer_list terms, ++num_calls_initializer_list_constructor_; #endif // MATH_OPT_USE_EXPRESSION_COUNTERS for (const auto& term : terms) { - SetOrCheckStorage(term.variable.storage()); + SetOrCheckStorage(term.variable); // The same variable may appear multiple times in the input list; we must // accumulate the coefficients. terms_[term.variable] += term.coefficient; @@ -1579,7 +1505,7 @@ LinearExpression& LinearExpression::operator+=(const LinearExpression& other) { // thus we don't need to compare in the loop. Of course this only applies if // the other has terms. if (!other.terms_.empty()) { - SetOrCheckStorage(other.storage()); + SetOrCheckStorage(other); for (const auto& [v, coeff] : other.terms_) { terms_[v] += coeff; } @@ -1589,13 +1515,13 @@ LinearExpression& LinearExpression::operator+=(const LinearExpression& other) { } LinearExpression& LinearExpression::operator+=(const LinearTerm& term) { - SetOrCheckStorage(term.variable.storage()); + SetOrCheckStorage(term.variable); terms_[term.variable] += term.coefficient; return *this; } LinearExpression& LinearExpression::operator+=(const Variable variable) { - SetOrCheckStorage(variable.storage()); + SetOrCheckStorage(variable); return *this += LinearTerm(variable, 1.0); } @@ -1607,7 +1533,7 @@ LinearExpression& LinearExpression::operator+=(const double value) { LinearExpression& LinearExpression::operator-=(const LinearExpression& other) { // See operator+=. if (!other.terms_.empty()) { - SetOrCheckStorage(other.storage()); + SetOrCheckStorage(other); for (const auto& [v, coeff] : other.terms_) { terms_[v] -= coeff; } @@ -1617,13 +1543,13 @@ LinearExpression& LinearExpression::operator-=(const LinearExpression& other) { } LinearExpression& LinearExpression::operator-=(const LinearTerm& term) { - SetOrCheckStorage(term.variable.storage()); + SetOrCheckStorage(term.variable); terms_[term.variable] -= term.coefficient; return *this; } LinearExpression& LinearExpression::operator-=(const Variable variable) { - SetOrCheckStorage(variable.storage()); + SetOrCheckStorage(variable); return *this -= LinearTerm(variable, 1.0); } @@ -1713,8 +1639,6 @@ const VariableMap& LinearExpression::terms() const { return terms_; } double LinearExpression::offset() const { return offset_; } -const ModelStorage* LinearExpression::storage() const { return storage_; } - //////////////////////////////////////////////////////////////////////////////// // VariablesEquality //////////////////////////////////////////////////////////////////////////////// @@ -2055,9 +1979,9 @@ BoundedLinearExpression operator==(const double lhs, const Variable rhs) { // QuadraticTermKey //////////////////////////////////////////////////////////////////////////////// -QuadraticTermKey::QuadraticTermKey(const ModelStorage* storage, +QuadraticTermKey::QuadraticTermKey(const ModelStorageCPtr storage, const QuadraticProductId id) - : storage_(storage), variable_ids_(id) { + : ModelStorageItem(storage), variable_ids_(id) { if (variable_ids_.first > variable_ids_.second) { // See https://en.cppreference.com/w/cpp/named_req/Swappable for details. using std::swap; @@ -2075,8 +1999,6 @@ QuadraticTermKey::QuadraticTermKey(const Variable first_variable, QuadraticProductId QuadraticTermKey::typed_id() const { return variable_ids_; } -const ModelStorage* QuadraticTermKey::storage() const { return storage_; } - template H AbslHashValue(H h, const QuadraticTermKey& key) { return H::combine(std::move(h), key.typed_id().first.value(), @@ -2124,17 +2046,9 @@ QuadraticTermKey QuadraticTerm::GetKey() const { // QuadraticExpression (no arithmetic) //////////////////////////////////////////////////////////////////////////////// -void QuadraticExpression::SetOrCheckStorage(const ModelStorage* const storage) { - CHECK(storage != nullptr) << internal::kKeyHasNullModelStorage; - if (storage_ == nullptr) { - storage_ = storage; - return; - } - CHECK_EQ(storage, storage_) << internal::kObjectsFromOtherModelStorage; -} - QuadraticExpression::QuadraticExpression(QuadraticExpression&& other) noexcept - : storage_(std::exchange(other.storage_, nullptr)), + : ModelStorageItemContainer( + static_cast(other)), quadratic_terms_(std::move(other.quadratic_terms_)), linear_terms_(std::move(other.linear_terms_)), offset_(std::exchange(other.offset_, 0.0)) { @@ -2147,7 +2061,8 @@ QuadraticExpression::QuadraticExpression(QuadraticExpression&& other) noexcept QuadraticExpression& QuadraticExpression::operator=( QuadraticExpression&& other) noexcept { - storage_ = std::exchange(other.storage_, nullptr); + ModelStorageItemContainer::operator=( + static_cast(other)); quadratic_terms_ = std::move(other.quadratic_terms_); other.quadratic_terms_.clear(); linear_terms_ = std::move(other.linear_terms_); @@ -2164,12 +2079,12 @@ QuadraticExpression::QuadraticExpression( ++num_calls_initializer_list_constructor_; #endif // MATH_OPT_USE_EXPRESSION_COUNTERS for (const LinearTerm& term : linear_terms) { - SetOrCheckStorage(term.variable.storage()); + SetOrCheckStorage(term.variable); linear_terms_[term.variable] += term.coefficient; } for (const QuadraticTerm& term : quadratic_terms) { const QuadraticTermKey key = term.GetKey(); - SetOrCheckStorage(key.storage()); + SetOrCheckStorage(key); quadratic_terms_[key] += term.coefficient(); } } @@ -2184,9 +2099,9 @@ QuadraticExpression::QuadraticExpression(const LinearTerm& term) : QuadraticExpression({}, {term}, 0.0) {} QuadraticExpression::QuadraticExpression(LinearExpression expr) - : storage_(std::exchange(expr.storage_, nullptr)), + : ModelStorageItemContainer(expr.storage()), linear_terms_(std::move(expr.terms_)), - offset_(std::exchange(expr.offset_, 0.0)) { + offset_(expr.offset_) { #ifdef MATH_OPT_USE_EXPRESSION_COUNTERS ++num_calls_linear_expression_constructor_; #endif // MATH_OPT_USE_EXPRESSION_COUNTERS @@ -2195,8 +2110,6 @@ QuadraticExpression::QuadraticExpression(LinearExpression expr) QuadraticExpression::QuadraticExpression(const QuadraticTerm& term) : QuadraticExpression({term}, {}, 0.0) {} -const ModelStorage* QuadraticExpression::storage() const { return storage_; } - double QuadraticExpression::offset() const { return offset_; } const VariableMap& QuadraticExpression::linear_terms() const { @@ -2582,13 +2495,13 @@ QuadraticExpression& QuadraticExpression::operator+=(const double value) { } QuadraticExpression& QuadraticExpression::operator+=(const Variable variable) { - SetOrCheckStorage(variable.storage()); + SetOrCheckStorage(variable); linear_terms_[variable] += 1; return *this; } QuadraticExpression& QuadraticExpression::operator+=(const LinearTerm& term) { - SetOrCheckStorage(term.variable.storage()); + SetOrCheckStorage(term.variable); linear_terms_[term.variable] += term.coefficient; return *this; } @@ -2598,7 +2511,7 @@ QuadraticExpression& QuadraticExpression::operator+=( offset_ += expr.offset(); // See comment in LinearExpression::operator+=. if (!expr.terms().empty()) { - SetOrCheckStorage(expr.storage()); + SetOrCheckStorage(expr); for (const auto& [v, coeff] : expr.terms()) { linear_terms_[v] += coeff; } @@ -2609,7 +2522,7 @@ QuadraticExpression& QuadraticExpression::operator+=( QuadraticExpression& QuadraticExpression::operator+=( const QuadraticTerm& term) { const QuadraticTermKey key = term.GetKey(); - SetOrCheckStorage(key.storage()); + SetOrCheckStorage(key); quadratic_terms_[key] += term.coefficient(); return *this; } @@ -2619,7 +2532,7 @@ QuadraticExpression& QuadraticExpression::operator+=( offset_ += expr.offset(); // See comment in LinearExpression::operator+=. if (!expr.linear_terms().empty() || !expr.quadratic_terms().empty()) { - SetOrCheckStorage(expr.storage()); + SetOrCheckStorage(expr); for (const auto& [v, coeff] : expr.linear_terms()) { linear_terms_[v] += coeff; } @@ -2637,13 +2550,13 @@ QuadraticExpression& QuadraticExpression::operator-=(const double value) { } QuadraticExpression& QuadraticExpression::operator-=(const Variable variable) { - SetOrCheckStorage(variable.storage()); + SetOrCheckStorage(variable); linear_terms_[variable] -= 1; return *this; } QuadraticExpression& QuadraticExpression::operator-=(const LinearTerm& term) { - SetOrCheckStorage(term.variable.storage()); + SetOrCheckStorage(term.variable); linear_terms_[term.variable] -= term.coefficient; return *this; } @@ -2653,7 +2566,7 @@ QuadraticExpression& QuadraticExpression::operator-=( offset_ -= expr.offset(); // See comment in LinearExpression::operator+=. if (!expr.terms().empty()) { - SetOrCheckStorage(expr.storage()); + SetOrCheckStorage(expr); for (const auto& [v, coeff] : expr.terms()) { linear_terms_[v] -= coeff; } @@ -2664,7 +2577,7 @@ QuadraticExpression& QuadraticExpression::operator-=( QuadraticExpression& QuadraticExpression::operator-=( const QuadraticTerm& term) { const QuadraticTermKey key = term.GetKey(); - SetOrCheckStorage(key.storage()); + SetOrCheckStorage(key); quadratic_terms_[key] -= term.coefficient(); return *this; } @@ -2674,7 +2587,7 @@ QuadraticExpression& QuadraticExpression::operator-=( offset_ -= expr.offset(); // See comment in LinearExpression::operator+=. if (!expr.linear_terms().empty() || !expr.quadratic_terms().empty()) { - SetOrCheckStorage(expr.storage()); + SetOrCheckStorage(expr); for (const auto& [v, coeff] : expr.linear_terms()) { linear_terms_[v] -= coeff; } diff --git a/ortools/math_opt/elemental/BUILD.bazel b/ortools/math_opt/elemental/BUILD.bazel new file mode 100644 index 0000000000..1fbe3dbf27 --- /dev/null +++ b/ortools/math_opt/elemental/BUILD.bazel @@ -0,0 +1,553 @@ +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +cc_library( + name = "attributes", + hdrs = ["attributes.h"], + visibility = ["//ortools/math_opt:__subpackages__"], + deps = [ + ":arrays", + ":elements", + ":symmetry", + "//ortools/base:array", + "@abseil-cpp//absl/strings:string_view", + ], +) + +cc_test( + name = "attributes_test", + srcs = ["attributes_test.cc"], + deps = [ + ":arrays", + ":attributes", + "//ortools/base:gmock_main", + "//ortools/math_opt/testing:stream", + "@abseil-cpp//absl/strings", + ], +) + +cc_library( + name = "elemental", + srcs = [ + "elemental.cc", + "elemental_export_model.cc", + "elemental_from_proto.cc", + "elemental_to_string.cc", + ], + hdrs = ["elemental.h"], + visibility = ["//ortools/math_opt:__subpackages__"], + deps = [ + ":arrays", + ":attr_key", + ":attr_storage", + ":attributes", + ":derived_data", + ":diff", + ":element_ref_tracker", + ":element_storage", + ":elements", + ":symmetry", + ":thread_safe_id_map", + "//ortools/base:status_macros", + "//ortools/math_opt:model_cc_proto", + "//ortools/math_opt:model_update_cc_proto", + "//ortools/math_opt:sparse_containers_cc_proto", + "//ortools/math_opt/core:model_summary", + "//ortools/math_opt/validators:model_validator", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/log:die_if_null", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:string_view", + "@abseil-cpp//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "elemental_test", + srcs = ["elemental_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":derived_data", + ":diff", + ":elemental", + ":elemental_matcher", + ":elements", + ":symmetry", + ":testing", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status", + "@com_google_benchmark//:benchmark", + ], +) + +cc_library( + name = "derived_data", + hdrs = ["derived_data.h"], + visibility = ["//ortools/math_opt:__subpackages__"], + deps = [ + ":arrays", + ":attr_key", + ":attributes", + ":elements", + "//ortools/util:fp_roundtrip_conv", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/strings", + ], +) + +cc_test( + name = "derived_data_test", + srcs = ["derived_data_test.cc"], + deps = [ + ":arrays", + ":attr_key", + ":attributes", + ":derived_data", + ":elements", + ":symmetry", + "//ortools/base:gmock_main", + "//ortools/math_opt/testing:stream", + ], +) + +cc_library( + name = "element_storage", + srcs = ["element_storage.cc"], + hdrs = ["element_storage.h"], + deps = [ + "//ortools/base:status_macros", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings:string_view", + ], +) + +cc_test( + name = "element_storage_test", + srcs = ["element_storage_test.cc"], + deps = [ + ":element_storage", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/status", + "@com_google_benchmark//:benchmark", + ], +) + +cc_library( + name = "element_diff", + hdrs = ["element_diff.h"], + deps = ["@abseil-cpp//absl/container:flat_hash_set"], +) + +cc_test( + name = "element_diff_test", + srcs = ["element_diff_test.cc"], + deps = [ + ":element_diff", + "//ortools/base:gmock_main", + ], +) + +cc_library( + name = "diff", + srcs = ["diff.cc"], + hdrs = ["diff.h"], + deps = [ + "derived_data", + ":attr_diff", + ":attr_key", + ":attributes", + ":element_diff", + ":elements", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/types:span", + ], +) + +cc_test( + name = "diff_test", + srcs = ["diff_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":diff", + ":elements", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/types:span", + ], +) + +cc_library( + name = "attr_storage", + hdrs = ["attr_storage.h"], + deps = [ + ":attr_key", + ":symmetry", + "//ortools/base:map_util", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/functional:function_ref", + ], +) + +cc_test( + name = "attr_storage_test", + srcs = ["attr_storage_test.cc"], + deps = [ + ":attr_key", + ":attr_storage", + ":symmetry", + "//ortools/base:gmock_main", + "@com_google_benchmark//:benchmark", + ], +) + +cc_library( + name = "attr_diff", + hdrs = ["attr_diff.h"], + deps = [ + ":attr_key", + "@abseil-cpp//absl/container:flat_hash_set", + ], +) + +cc_test( + name = "attr_diff_test", + srcs = ["attr_diff_test.cc"], + deps = [ + ":attr_diff", + ":attr_key", + ":symmetry", + "//ortools/base:gmock_main", + ], +) + +cc_library( + name = "attr_key", + hdrs = ["attr_key.h"], + visibility = ["//ortools/math_opt:__subpackages__"], + deps = [ + ":elements", + ":symmetry", + "//ortools/base:status_macros", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:span", + ], +) + +cc_test( + name = "attr_key_test", + srcs = ["attr_key_test.cc"], + deps = [ + ":attr_key", + ":elements", + ":symmetry", + ":testing", + "//ortools/base:gmock_main", + "//ortools/math_opt/testing:stream", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/hash:hash_testing", + "@abseil-cpp//absl/meta:type_traits", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/strings", + "@com_google_benchmark//:benchmark", + ], +) + +cc_library( + name = "arrays", + hdrs = ["arrays.h"], + visibility = ["//ortools/math_opt/elemental:__subpackages__"], +) + +cc_library( + name = "elemental_differencer", + srcs = ["elemental_differencer.cc"], + hdrs = ["elemental_differencer.h"], + deps = [ + ":attr_key", + ":attributes", + ":derived_data", + ":elemental", + ":elements", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + ], +) + +cc_test( + name = "elemental_differencer_test", + srcs = ["elemental_differencer_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":elemental", + ":elemental_differencer", + ":elements", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/container:flat_hash_set", + ], +) + +cc_test( + name = "arrays_test", + srcs = ["arrays_test.cc"], + deps = [ + ":arrays", + "//ortools/base:array", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:string_view", + ], +) + +cc_test( + name = "elemental_export_model_test", + srcs = ["elemental_export_model_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":derived_data", + ":elemental", + ":elements", + "//ortools/base:gmock_main", + "//ortools/math_opt:model_cc_proto", + "//ortools/math_opt:sparse_containers_cc_proto", + ], +) + +cc_test( + name = "elemental_to_string_test", + srcs = ["elemental_to_string_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":elemental", + ":elements", + "//ortools/base:gmock_main", + "//ortools/math_opt/testing:stream", + "@abseil-cpp//absl/strings", + ], +) + +cc_test( + name = "safe_attr_ops_test", + srcs = ["safe_attr_ops_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":elemental", + ":elements", + ":safe_attr_ops", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/status", + ], +) + +cc_library( + name = "safe_attr_ops", + hdrs = ["safe_attr_ops.h"], + visibility = ["//ortools/math_opt/elemental/c_api:__subpackages__"], + deps = [ + ":derived_data", + ":elemental", + "//ortools/base:status_macros", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + ], +) + +cc_library( + name = "testing", + testonly = 1, + hdrs = ["testing.h"], + deps = [":attr_key"], +) + +cc_library( + name = "symmetry", + hdrs = ["symmetry.h"], + deps = [ + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/strings:string_view", + ], +) + +cc_library( + name = "elemental_matcher", + testonly = 1, + srcs = ["elemental_matcher.cc"], + hdrs = ["elemental_matcher.h"], + deps = [ + ":elemental", + ":elemental_differencer", + "//ortools/base:gmock", + "@abseil-cpp//absl/base:core_headers", + ], +) + +cc_library( + name = "element_ref_tracker", + hdrs = ["element_ref_tracker.h"], + deps = [ + ":attr_key", + ":elements", + "//ortools/base:map_util", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + ], +) + +cc_library( + name = "elements", + srcs = ["elements.cc"], + hdrs = ["elements.h"], + visibility = ["//ortools/math_opt:__subpackages__"], + deps = [ + "//ortools/base:array", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/strings:string_view", + ], +) + +cc_test( + name = "elements_test", + srcs = ["elements_test.cc"], + deps = [ + ":elements", + "//ortools/base:gmock_main", + "//ortools/math_opt/testing:stream", + "@abseil-cpp//absl/hash:hash_testing", + "@abseil-cpp//absl/strings", + ], +) + +cc_test( + name = "elemental_matcher_test", + srcs = ["elemental_matcher_test.cc"], + deps = [ + ":elemental", + ":elemental_differencer", + ":elemental_matcher", + ":elements", + "//ortools/base:gmock_main", + ], +) + +cc_test( + name = "element_ref_tracker_test", + srcs = ["element_ref_tracker_test.cc"], + deps = [ + ":attr_key", + ":element_ref_tracker", + ":elements", + ":symmetry", + "//ortools/base:gmock_main", + ], +) + +cc_library( + name = "thread_safe_id_map", + hdrs = ["thread_safe_id_map.h"], + deps = [ + "//ortools/base:stl_util", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/synchronization", + "@abseil-cpp//absl/types:span", + ], +) + +cc_test( + name = "thread_safe_id_map_test", + srcs = ["thread_safe_id_map_test.cc"], + deps = [ + ":thread_safe_id_map", + "//ortools/base:gmock_main", + ], +) + +cc_test( + name = "elemental_from_proto_test", + srcs = ["elemental_from_proto_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":derived_data", + ":elemental", + ":elemental_matcher", + ":elements", + "//ortools/base:gmock_main", + "//ortools/math_opt:model_cc_proto", + "//ortools/math_opt:sparse_containers_cc_proto", + "@abseil-cpp//absl/status", + ], +) + +cc_test( + name = "elemental_from_proto_fuzz_test", + srcs = ["elemental_from_proto_fuzz_test.cc"], + tags = ["componentid:1147829"], + deps = [ + ":elemental", + ":elemental_matcher", + "//ortools/base:fuzztest", + "//ortools/base:gmock_main", + "//ortools/math_opt:model_update_cc_proto", + "@abseil-cpp//absl/status:statusor", + ], +) + +cc_test( + name = "elemental_update_from_proto_test", + srcs = ["elemental_update_from_proto_test.cc"], + deps = [ + ":attr_key", + ":attributes", + ":derived_data", + ":elemental", + ":elemental_matcher", + ":elements", + "//ortools/base:gmock_main", + "//ortools/math_opt:model_cc_proto", + "//ortools/math_opt:model_update_cc_proto", + "//ortools/math_opt:sparse_containers_cc_proto", + "@abseil-cpp//absl/status", + ], +) diff --git a/ortools/math_opt/elemental/CMakeLists.txt b/ortools/math_opt/elemental/CMakeLists.txt new file mode 100644 index 0000000000..cf5db2bacd --- /dev/null +++ b/ortools/math_opt/elemental/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(NAME ${PROJECT_NAME}_math_opt_elemental) +add_library(${NAME} OBJECT) + +file(GLOB_RECURSE _SRCS "*.h" "*.cc") +list(FILTER _SRCS EXCLUDE REGEX ".*/.*_test.cc") +list(FILTER _SRCS EXCLUDE REGEX "/elemental_matcher.*") +list(FILTER _SRCS EXCLUDE REGEX "/python/.*") +list(FILTER _SRCS EXCLUDE REGEX "/codegen/codegen.cc") +target_sources(${NAME} PRIVATE ${_SRCS}) +set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(${NAME} PUBLIC + $ + $) +target_link_libraries(${NAME} PRIVATE + ${PROJECT_NAMESPACE}::math_opt_proto + absl::strings +) diff --git a/ortools/math_opt/elemental/README.md b/ortools/math_opt/elemental/README.md new file mode 100644 index 0000000000..13df2a7fd1 --- /dev/null +++ b/ortools/math_opt/elemental/README.md @@ -0,0 +1,3 @@ +# Elemental + +See go/math-opt-elemental and g/math-opt-dev/c/0cgOO6qkoWM. diff --git a/ortools/math_opt/elemental/arrays.h b/ortools/math_opt/elemental/arrays.h new file mode 100644 index 0000000000..c7cc163709 --- /dev/null +++ b/ortools/math_opt/elemental/arrays.h @@ -0,0 +1,74 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities to apply template functors on index ranges. +// See tests for examples. +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_ + +#include +#include + +namespace operations_research::math_opt { + +// Calls `fn<0, ..., n-1>()`, and returns the result. Typically used for +// simple reduce operations that can be expressed as a fold. +// +// Examples: +// - Sum of elements from 0 to 5 (result is 15): +// `ApplyOnIndexRange<6>([]() { return (i + ... + 0); });` +// +// - Sum of elements of array `a`: +// ``` +// ApplyOnIndexRange([&a]() { +// return (a[i] + ... + 0); +// }); +// ``` +template +constexpr decltype(auto) ApplyOnIndexRange(Fn&& fn) { + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + return [&fn](std::integer_sequence) mutable { + return fn.template operator()(); + }(std::make_integer_sequence()); +} + +// Calls (fn<0>(), ..., fn()). +// Typically used for independent operations on elements, or more complex reduce +// operations that cannot be expressed with a fold. +// +// Example (independent operations): Log each array element for some array `a`: +// `ForEachIndex([&a]() { LOG(ERROR) << a[i]; });` +// +// NOTE: this returns the result of the last call, which allows returning some +// internal state (and avoids capturing an external variable by reference) for +// complex fold operations. See `CollectTest` for an example. +template +constexpr decltype(auto) ForEachIndex(Fn&& fn) { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + [&fn]() { return (fn.template operator()(), ...); }); +} + +// Calls `fn` of each element of `tuple`, and returns the result of the +// last invocation. +template +constexpr decltype(auto) ForEach(Fn&& fn, Tuple&& tuple) { + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + return std::apply([&fn]( + Ts&&... ts) { return (fn(std::forward(ts)), ...); }, + std::forward(tuple)); +} + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_ diff --git a/ortools/math_opt/elemental/arrays_test.cc b/ortools/math_opt/elemental/arrays_test.cc new file mode 100644 index 0000000000..865b28caeb --- /dev/null +++ b/ortools/math_opt/elemental/arrays_test.cc @@ -0,0 +1,179 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/arrays.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gtest/gtest.h" +#include "ortools/base/array.h" +#include "ortools/base/gmock.h" + +namespace operations_research::math_opt { +namespace { + +using ::testing::ElementsAre; + +// Sums the elements of an array-like object `a`. +template +constexpr int Sum() { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + []() { return (a[i] + ... + 0); }); +} + +// Same as `Sum`, but starts at 1. +template +constexpr int SumPlusOne() { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + []() { return (a[i] + ... + 1); }); +} + +#if __cplusplus >= 202002L +// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat) +TEST(ApplyOnIndexRangeTest, Sum) { + EXPECT_EQ(Sum(), 9); + EXPECT_EQ(SumPlusOne(), 10); +} +// NOLINTEND(clang-diagnostic-pre-c++20-compat) +#endif + +// Returns the weighted sum of the elements of an array-like object `a`, where +// weights are indices. +template +constexpr double ScaledSum() { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + []() { return ((i * a[i]) + ... + 0.0); }); +} + +#if __cplusplus >= 202002L +// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat) +TEST(ApplyOnIndexRangeTest, ScaledSum) { + EXPECT_EQ(ScaledSum(), 5.0); +} +// NOLINTEND(clang-diagnostic-pre-c++20-compat) +#endif + +// Returns the number of even elements in an array-like object `a`. +template +constexpr int CountEven() { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + []() { return ((a[i] % 2 == 0 ? 1 : 0) + ... + 0); }); +} + +#if __cplusplus >= 202002L +// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat) +TEST(ApplyOnIndexRangeTest, CountEven) { + EXPECT_EQ(CountEven(), 3); +} +// NOLINTEND(clang-diagnostic-pre-c++20-compat) +#endif + +// Returns array of doubles of the same size as `a`, where each element has been +// halved. +template +constexpr std::array Half() { + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + return ApplyOnIndexRange([]() { + return std::array( + {(static_cast(a[i]) / 2.0)...}); + }); +} + +#if __cplusplus >= 202002L +// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat) +TEST(ApplyOnIndexRangeTest, Half) { + EXPECT_THAT(Half(), + ElementsAre(2.5, 2.0, 4.0, 0.5, 5.0)); +} +// NOLINTEND(clang-diagnostic-pre-c++20-compat) +#endif + +// Returns true of all elements of `a` are even. +template +constexpr int AllEven() { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + []() { return (((a[i] % 2) == 0) && ...); }); +} + +// Returns true of any element of `a` is even. +template +constexpr int AnyEven() { + return ApplyOnIndexRange( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + []() { return (((a[i] % 2) == 0) || ...); }); +} + +#if __cplusplus >= 202002L +// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat) +TEST(ApplyOnIndexRangeTest, Even) { + EXPECT_FALSE(AllEven()); + EXPECT_TRUE(AnyEven()); + + EXPECT_TRUE(AllEven()); + EXPECT_TRUE(AnyEven()); + + EXPECT_FALSE(AllEven()); + EXPECT_FALSE(AnyEven()); +} +// NOLINTEND(clang-diagnostic-pre-c++20-compat) +#endif + +// A example of a more complex reduce operation using `ForEachIndex`. Here, we +// want to collect a list of integers for which an operation (`may_fail`) +// failed. +TEST(ForEachIndexTest, CollectTest) { + constexpr auto may_fail = [](int i) { + if (i == 3 || i == 7 || i == 42) { + return absl::InvalidArgumentError("bad number"); + } + return absl::OkStatus(); + }; + + EXPECT_THAT( + ForEachIndex<21>( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + [&may_fail, failed_indices = std::vector()]() mutable + -> const std::vector& { + if (!may_fail(i).ok()) { + failed_indices.push_back(i); + } + return failed_indices; + }), + ElementsAre(3, 7)); +} + +TEST(ForEachTest, StrCatHeterogeneousTypes) { + EXPECT_EQ( + ForEach( + [r = std::string()](const auto& v) mutable -> absl::string_view { + absl::StrAppend(&r, " ", v); + return r; + }, + std::make_tuple("a", 1, 0.5)), + " a 1 0.5"); +} + +} // namespace +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/elemental/attr_diff.h b/ortools/math_opt/elemental/attr_diff.h new file mode 100644 index 0000000000..4417e8b948 --- /dev/null +++ b/ortools/math_opt/elemental/attr_diff.h @@ -0,0 +1,58 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "ortools/math_opt/elemental/attr_key.h" + +namespace operations_research::math_opt { + +// Tracks modifications to an Attribute with a key size of n (e.g., variable +// lower bound has a key size of 1). +template +class AttrDiff { + public: + using Key = AttrKey; + + // On creation, the attribute is not modified for any key. + AttrDiff() = default; + + // Clear all tracked modifications. + void Advance() { modified_keys_.clear(); } + + // Mark the attribute as modified for `key`. + void SetModified(const Key key) { modified_keys_.insert(key); } + + // Returns the attribute keys that have been modified for this attribute (the + // elements where set_modified() was called without a subsequent call to + // Advance()). + const AttrKeyHashSet& modified_keys() const { return modified_keys_; } + + bool has_modified_keys() const { return !modified_keys_.empty(); } + + // Stop tracking modifications for this attribute key. (Typically invoked when + // an element in the key was deleted from the model.) + void Erase(const Key key) { modified_keys_.erase(key); } + + private: + AttrKeyHashSet modified_keys_; +}; + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_ diff --git a/ortools/math_opt/elemental/attr_diff_test.cc b/ortools/math_opt/elemental/attr_diff_test.cc new file mode 100644 index 0000000000..efc326dc21 --- /dev/null +++ b/ortools/math_opt/elemental/attr_diff_test.cc @@ -0,0 +1,168 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/attr_diff.h" + +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/math_opt/elemental/attr_key.h" +#include "ortools/math_opt/elemental/symmetry.h" + +namespace operations_research::math_opt { + +using ::testing::IsEmpty; +using ::testing::UnorderedElementsAre; + +//////////////////////////////////////////////////////////////////////////////// +// AttrDiff<0> +//////////////////////////////////////////////////////////////////////////////// + +TEST(AttrDiff0Test, InitNotModified) { + AttrDiff<0, NoSymmetry> diff; + EXPECT_THAT(diff.modified_keys(), IsEmpty()); +} + +TEST(AttrDiff0Test, SetModified) { + AttrDiff<0, NoSymmetry> diff; + diff.SetModified(AttrKey()); + EXPECT_THAT(diff.modified_keys(), UnorderedElementsAre(AttrKey())); +} + +TEST(AttrDiff0Test, Advance) { + AttrDiff<0, NoSymmetry> diff; + diff.SetModified(AttrKey()); + diff.Advance(); + EXPECT_THAT(diff.modified_keys(), IsEmpty()); +} + +//////////////////////////////////////////////////////////////////////////////// +// Attr1Diff +//////////////////////////////////////////////////////////////////////////////// + +TEST(AttrDiff1Test, InitNotModified) { + AttrDiff<1, NoSymmetry> diff; + EXPECT_THAT(diff.modified_keys(), IsEmpty()); +} + +TEST(AttrDiff1Test, SetModified) { + AttrDiff<1, NoSymmetry> diff; + diff.SetModified(AttrKey(2)); + diff.SetModified(AttrKey(5)); + diff.SetModified(AttrKey(6)); + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(AttrKey(2), AttrKey(5), AttrKey(6))); +} + +TEST(AttrDiff1Test, Advance) { + AttrDiff<1, NoSymmetry> diff; + diff.SetModified(AttrKey(2)); + diff.SetModified(AttrKey(5)); + + diff.Advance(); + EXPECT_THAT(diff.modified_keys(), IsEmpty()); +} + +TEST(AttrDiff1Test, EraseIsModifiedGetsRemoved) { + AttrDiff<1, NoSymmetry> diff; + diff.SetModified(AttrKey(2)); + diff.SetModified(AttrKey(5)); + diff.SetModified(AttrKey(6)); + + diff.Erase(AttrKey(5)); + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(AttrKey(2), AttrKey(6))); +} + +TEST(AttrDiff1Test, EraseNotModifiedNoEffect) { + AttrDiff<1, NoSymmetry> diff; + diff.SetModified(AttrKey(2)); + diff.SetModified(AttrKey(5)); + + diff.Erase(AttrKey(1)); + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(AttrKey(2), AttrKey(5))); +} + +//////////////////////////////////////////////////////////////////////////////// +// Attr2Diff +//////////////////////////////////////////////////////////////////////////////// + +TEST(AttrDiffTest2, InitNotModified) { + AttrDiff<2, NoSymmetry> diff; + EXPECT_THAT(diff.modified_keys(), IsEmpty()); +} + +TEST(AttrDiffTest2, SetModified) { + AttrDiff<2, NoSymmetry> diff; + diff.SetModified(AttrKey(2, 4)); + diff.SetModified(AttrKey(5, 2)); + diff.SetModified(AttrKey(2, 5)); + diff.SetModified(AttrKey(6, 6)); + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(AttrKey(2, 4), AttrKey(5, 2), AttrKey(2, 5), + AttrKey(6, 6))); +} + +TEST(AttrDiffTest2, Advance) { + AttrDiff<2, NoSymmetry> diff; + diff.SetModified(AttrKey(2, 3)); + diff.SetModified(AttrKey(2, 8)); + + diff.Advance(); + EXPECT_THAT(diff.modified_keys(), IsEmpty()); +} + +TEST(AttrDiffTest2, EraseIsModifiedGetsRemoved) { + AttrDiff<2, NoSymmetry> diff; + diff.SetModified(AttrKey(2, 5)); + diff.SetModified(AttrKey(4, 3)); + diff.SetModified(AttrKey(3, 4)); + diff.SetModified(AttrKey(6, 6)); + + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(AttrKey(2, 5), AttrKey(3, 4), AttrKey(4, 3), + AttrKey(6, 6))); + + diff.Erase(AttrKey(4, 3)); + EXPECT_THAT( + diff.modified_keys(), + UnorderedElementsAre(AttrKey(2, 5), AttrKey(3, 4), AttrKey(6, 6))); +} + +TEST(AttrDiffTest2, EraseIsModifiedGetsRemovedSymmetric) { + using Diff = AttrDiff<2, ElementSymmetry<0, 1>>; + using Key = Diff::Key; + Diff diff; + diff.SetModified(Key(2, 5)); + diff.SetModified(Key(4, 3)); + diff.SetModified(Key(3, 4)); // Noop, same as (4,3). + diff.SetModified(Key(6, 6)); + + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(Key(2, 5), Key(3, 4), Key(6, 6))); + + diff.Erase(Key(4, 3)); + EXPECT_THAT(diff.modified_keys(), UnorderedElementsAre(Key(2, 5), Key(6, 6))); +} + +TEST(AttrDiffTest2, EraseNotModifiedNoEffect) { + AttrDiff<2, NoSymmetry> diff; + diff.SetModified(AttrKey(2, 5)); + diff.SetModified(AttrKey(6, 6)); + + diff.Erase(AttrKey(1, 3)); + EXPECT_THAT(diff.modified_keys(), + UnorderedElementsAre(AttrKey(2, 5), AttrKey(6, 6))); +} + +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/elemental/attr_key.h b/ortools/math_opt/elemental/attr_key.h new file mode 100644 index 0000000000..294788e27c --- /dev/null +++ b/ortools/math_opt/elemental/attr_key.h @@ -0,0 +1,360 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "ortools/base/status_builder.h" +#include "ortools/math_opt/elemental/elements.h" +#include "ortools/math_opt/elemental/symmetry.h" + +namespace operations_research::math_opt { + +// An attribute key for an attribute keyed on `n` elements. +// `AttrKey` is a value type. +template +class AttrKey { + public: + using value_type = int64_t; + using SymmetryT = Symmetry; + + // Default constructor: values are uninitialized. + constexpr AttrKey() {} // NOLINT: uninitialized on purpose. + + template && ...))>> + explicit constexpr AttrKey(const Ints... ids) { + auto push_back = [this, i = 0](auto e) mutable { element_ids_[i++] = e; }; + (push_back(ids), ...); + Symmetry::Enforce(element_ids_); + } + + template > + explicit constexpr AttrKey(const ElementId... ids) + : AttrKey(ids.value()...) {} + + constexpr AttrKey(std::array ids) // NOLINT: pybind11. + : element_ids_(ids) { + Symmetry::Enforce(element_ids_); + } + + // Canonicalizes a non-canonical key, i.e., enforces the symmetries + static constexpr AttrKey Canonicalize(AttrKey key) { + return AttrKey(key.element_ids_); + } + + // Creates a key from a range of `n` elements. + static absl::StatusOr FromRange(absl::Span range) { + if (range.size() != n) { + return ::util::InvalidArgumentErrorBuilder() + << "cannot build AttrKey<" << n << "> from a range of size " + << range.size(); + } + AttrKey result; + std::copy(range.begin(), range.end(), result.element_ids_.begin()); + Symmetry::Enforce(result.element_ids_); + return result; + } + + constexpr AttrKey(const AttrKey&) = default; + constexpr AttrKey& operator=(const AttrKey&) = default; + + static constexpr int size() { return n; } + + // Element access. + constexpr value_type operator[](const int dim) const { + DCHECK_LT(dim, n); + DCHECK_GE(dim, 0); + return element_ids_[dim]; + } + constexpr value_type& operator[](const int dim) { + DCHECK_LT(dim, n); + DCHECK_GE(dim, 0); + return element_ids_[dim]; + } + + // Element iteration. + constexpr const value_type* begin() const { return element_ids_.begin(); } + constexpr const value_type* end() const { return element_ids_.end(); } + + // `AttrKey` is comparable (ordering is lexicographic) and hashable. + // + // TODO(b/365998156): post C++ 20, replace by spaceship operator (with all + // comparison operators below). Do NOT use the default generated operator (see + // below). + friend constexpr bool operator==(const AttrKey& l, const AttrKey& r) { + // This is much faster than using the default generated `operator==`. + for (int i = 0; i < n; ++i) { + if (l.element_ids_[i] != r.element_ids_[i]) { + return false; + } + } + return true; + } + + friend constexpr bool operator<(const AttrKey& l, const AttrKey& r) { + // This is much faster than using the default generated `operator<`. + for (int i = 0; i < n; ++i) { + if (l.element_ids_[i] < r.element_ids_[i]) { + return true; + } + if (l.element_ids_[i] > r.element_ids_[i]) { + return false; + } + } + return false; + } + + friend constexpr bool operator<=(const AttrKey& l, const AttrKey& r) { + // This is much faster than using the default generated `operator<`. + for (int i = 0; i < n; ++i) { + if (l.element_ids_[i] < r.element_ids_[i]) { + return true; + } + if (l.element_ids_[i] > r.element_ids_[i]) { + return false; + } + } + return true; + } + + friend constexpr bool operator>(const AttrKey& l, const AttrKey& r) { + return r < l; + } + + friend constexpr bool operator>=(const AttrKey& l, const AttrKey& r) { + return r <= l; + } + + template + friend H AbslHashValue(H h, const AttrKey& a) { + return H::combine_contiguous(std::move(h), a.element_ids_.data(), n); + } + + // `AttrKey` is printable for logging and tests. + template + friend void AbslStringify(Sink& sink, const AttrKey& key) { + sink.Append(absl::StrCat( + "AttrKey(", absl::StrJoin(absl::MakeSpan(key.element_ids_), ", "), + ")")); + } + + // Removes the element at dimension `dim` from the key and returns a key with + // only remaining dimensions. + template + AttrKey RemoveElement() const { + static_assert(dim >= 0); + static_assert(dim < n); + AttrKey result; + for (int i = 0; i < dim; ++i) { + result.element_ids_[i] = element_ids_[i]; + } + for (int i = dim + 1; i < n; ++i) { + result.element_ids_[i - 1] = element_ids_[i]; + } + return result; + } + + // Adds element `elem` at dimension `dim` and returns the result. + // The result must respect `NewSymmetry` (we `DCHECK` this). + template + AttrKey AddElement(const value_type elem) const { + static_assert(dim >= 0); + static_assert(dim < n + 1); + AttrKey result; + for (int i = 0; i < dim; ++i) { + result.element_ids_[i] = element_ids_[i]; + } + result.element_ids_[dim] = elem; + for (int i = dim + 1; i < n + 1; ++i) { + result.element_ids_[i] = element_ids_[i - 1]; + } + DCHECK(NewSymmetry::Validate(result.element_ids_)) + << result << " does not have `" << NewSymmetry::GetName() + << "` symmetry"; + return result; + } + + private: + template + friend class AttrKey; + std::array element_ids_; +}; + +// CTAD for `AttrKey(1,2)`. +template +AttrKey(Ints... dims) -> AttrKey; + +// Traits to detect whether `T` is an `AttrKey`. +template +struct is_attr_key : public std::false_type {}; + +template +struct is_attr_key> : public std::true_type {}; + +template +static constexpr inline bool is_attr_key_v = is_attr_key::value; + +// Required for open-source `StatusBuilder` support. +template +std::ostream& operator<<(std::ostream& ostr, const AttrKey& key) { + ostr << absl::StrCat(key); + return ostr; +} + +namespace detail { +// A set of zero or one `AttrKey<0, Symmetry>, V`. This is used to make +// implementations of `AttrDiff` and `AttrStorage` uniform. +// `V` must by default constructible, trivially destructible and copyable +// (we'll fail to compile otherwise). +// After c++26, optional is a sequence container, so this can pretty much become +// `std::optional>` + `find()`. +template +class AttrKey0RawSet { + public: + using value_type = V; + using Key = AttrKey<0, Symmetry>; + + template + class IteratorImpl { + public: + IteratorImpl() = default; + // `iterator` converts to `const_iterator`. + IteratorImpl(const IteratorImpl>& other) // NOLINT + : value_(other.value_) {} + + // Dereference. + ValueT& operator*() const { + DCHECK_NE(value_, nullptr); + return *value_; + } + ValueT* operator->() const { + DCHECK_NE(value_, nullptr); + return value_; + } + + // Increment. + IteratorImpl& operator++() { + DCHECK_NE(value_, nullptr); + value_ = nullptr; + return *this; + } + + // Equality. + friend bool operator==(const IteratorImpl& l, const IteratorImpl& r) { + return l.value_ == r.value_; + } + friend bool operator!=(const IteratorImpl& l, const IteratorImpl& r) { + return !(l == r); + } + + private: + friend class AttrKey0RawSet; + explicit IteratorImpl(ValueT& value) : value_(&value) {} + + ValueT* value_ = nullptr; + }; + + using iterator = IteratorImpl; + using const_iterator = IteratorImpl; + + AttrKey0RawSet() = default; + + bool empty() const { return !engaged_; } + size_t size() const { return engaged_ ? 1 : 0; } + + const_iterator begin() const { + return engaged_ ? const_iterator(value_) : const_iterator(); + } + const_iterator end() const { return const_iterator(); } + iterator begin() { return engaged_ ? iterator(value_) : iterator(); } + iterator end() { return iterator(); } + + bool contains(Key) const { return engaged_; } + const_iterator find(Key) const { return begin(); } + iterator find(Key) { return begin(); } + + void clear() { engaged_ = false; } + size_t erase(Key) { + if (engaged_) { + engaged_ = false; + return 1; + } + return 0; + } + size_t erase(const_iterator) { return erase(Key()); } + + template + std::pair try_emplace(Key, Args&&... args) { + if (engaged_) { + return std::make_pair(iterator(value_), false); + } + value_ = value_type(Key(), std::forward(args)...); + engaged_ = true; + return std::make_pair(iterator(value_), true); + } + + std::pair insert(const value_type& v) { + if (engaged_) { + return std::make_pair(iterator(value_), false); + } + value_ = v; + engaged_ = true; + return std::make_pair(iterator(value_), true); + } + + private: + // The following greatly simplifies the implementation because we don't have + // to worry about side effects of the dtor (see e.g. `clear()`). + static_assert(std::is_trivially_destructible_v); + + bool engaged_ = false; + value_type value_; +}; + +} // namespace detail + +// A hash set of `AttrKeyT`, where `AttrKeyT` is an `AttrKey`. +template >> +using AttrKeyHashSet = std::conditional_t< + (AttrKeyT::size() > 0), absl::flat_hash_set, + detail::AttrKey0RawSet>; + +// A hash map of `AttrKeyT` to `V`, where `AttrKeyT` is an +// `AttrKey`. +template >> +using AttrKeyHashMap = + std::conditional_t<(AttrKeyT::size() > 0), absl::flat_hash_map, + detail::AttrKey0RawSet>>; + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_ diff --git a/ortools/math_opt/elemental/attr_key_test.cc b/ortools/math_opt/elemental/attr_key_test.cc new file mode 100644 index 0000000000..8e630f26fa --- /dev/null +++ b/ortools/math_opt/elemental/attr_key_test.cc @@ -0,0 +1,335 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/attr_key.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash_testing.h" +#include "absl/meta/type_traits.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/math_opt/elemental/elements.h" +#include "ortools/math_opt/elemental/symmetry.h" +#include "ortools/math_opt/elemental/testing.h" +#include "ortools/math_opt/testing/stream.h" + +namespace operations_research::math_opt { +namespace { +using testing::ElementsAre; +using testing::HasSubstr; +using testing::IsEmpty; +using testing::Pair; +using testing::SizeIs; +using testing::UnorderedElementsAre; +using testing::status::IsOkAndHolds; +using testing::status::StatusIs; + +static_assert(sizeof(AttrKey<0>) <= sizeof(uint64_t)); +static_assert(sizeof(AttrKey<1>) == sizeof(uint64_t)); +static_assert(sizeof(AttrKey<2>) == 2 * sizeof(uint64_t)); +static_assert(sizeof(AttrKey<2, ElementSymmetry<0, 1>>) == + 2 * sizeof(uint64_t)); + +// Make sure that passing AttrKey by value really puts it in registers rather +// than leaving it in the caller's frame (see +// https://itanium-cxx-abi.github.io/cxx-abi/abi.html#non-trivial). +static_assert(absl::is_trivially_relocatable>()); +static_assert(absl::is_trivially_relocatable>()); +static_assert(absl::is_trivially_relocatable>()); +static_assert( + absl::is_trivially_relocatable>>()); + +TEST(AttrKeyTest, CtorAndIteration) { + EXPECT_THAT(AttrKey(), ElementsAre()); + EXPECT_THAT(AttrKey(1), ElementsAre(1)); + EXPECT_THAT(AttrKey(1, 2), ElementsAre(1, 2)); +} + +TEST(AttrKeyTest, ElementIdCtor) { + EXPECT_THAT(AttrKey(ElementId(1)), ElementsAre(1)); + EXPECT_THAT(AttrKey(ElementId(1), + ElementId(2)), + ElementsAre(1, 2)); +} + +TEST(AttrKeyTest, ElementAccess) { + const AttrKey key(1, 2); + EXPECT_EQ(key[0], 1); + EXPECT_EQ(key[1], 2); + + AttrKey mutable_key(1, 2); + EXPECT_EQ(mutable_key[0], 1); + EXPECT_EQ(mutable_key[1], 2); +} + +TEST(AttrKeyTest, SupportsAbslHash1) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + AttrKey(1), + AttrKey(2), + AttrKey(0), + })); +} + +TEST(AttrKeyTest, SupportsAbslHash2) { + EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ + AttrKey(1, 2), + AttrKey(2, 3), + AttrKey(0, 0), + })); +} + +TEST(AttrKeyTest, Stringify) { + EXPECT_EQ(absl::StrCat(AttrKey(1, 2, 3)), "AttrKey(1, 2, 3)"); + EXPECT_EQ(StreamToString(AttrKey(1, 2, 3)), "AttrKey(1, 2, 3)"); +} + +TEST(AttrKeyTest, AddRemove) { + const AttrKey key0; + EXPECT_THAT(key0, ElementsAre()); + const AttrKey key1 = key0.AddElement<0, NoSymmetry>(3); + EXPECT_THAT(key1, ElementsAre(3)); + const AttrKey key2 = key1.AddElement<0, NoSymmetry>(1); + EXPECT_THAT(key2, ElementsAre(1, 3)); + const AttrKey key3 = key2.AddElement<1, NoSymmetry>(2); + EXPECT_THAT(key3, ElementsAre(1, 2, 3)); + const AttrKey key4 = key3.AddElement<3, NoSymmetry>(4); + EXPECT_THAT(key4, ElementsAre(1, 2, 3, 4)); +} + +TEST(AttrKeyTest, AddRemoveNotSymmetric) { + using NoSym = NoSymmetry; + EXPECT_THAT((AttrKey(0, 2).AddElement<1, NoSym>(1)), ElementsAre(0, 1, 2)); + EXPECT_THAT((AttrKey(0, 1).AddElement<2, NoSym>(2)), ElementsAre(0, 1, 2)); + EXPECT_THAT((AttrKey(0, 1).AddElement<1, NoSym>(2)), ElementsAre(0, 2, 1)); + EXPECT_THAT((AttrKey(0, 2).AddElement<2, NoSym>(1)), ElementsAre(0, 2, 1)); +} + +TEST(AttrKeyDeathTest, AddRemoveSymmetric) { + using Sym01 = ElementSymmetry<1, 2>; + EXPECT_THAT((AttrKey(0, 2).AddElement<1, Sym01>(1)), ElementsAre(0, 1, 2)); + EXPECT_THAT((AttrKey(0, 1).AddElement<2, Sym01>(2)), ElementsAre(0, 1, 2)); +#ifndef NDEBUG + EXPECT_DEATH( + (AttrKey(0, 1).AddElement<1, Sym01>(2)), + HasSubstr( + "AttrKey(0, 2, 1) does not have `ElementSymmetry<1, 2>` symmetry")); + EXPECT_DEATH( + (AttrKey(0, 2).AddElement<2, Sym01>(1)), + HasSubstr( + "AttrKey(0, 2, 1) does not have `ElementSymmetry<1, 2>` symmetry")); +#endif +} + +TEST(AttrKeyTest, ComparisonOperators) { + // a[0] < a[1] < a[2] < a[3] < a[4] + const std::vector> a = {AttrKey(1, 0, 0, 0), AttrKey(2, 5, 1, 12), + AttrKey(2, 5, 3, 10), AttrKey(2, 5, 3, 11), + AttrKey(3, 0, 0, 0)}; + + // Now test each of the operators + for (int i = 0; i < a.size(); ++i) { + SCOPED_TRACE(absl::StrCat(i)); + for (int j = 0; j < a.size(); ++j) { + SCOPED_TRACE(absl::StrCat(j)); + if (i == j) { + EXPECT_FALSE(a[i] < a[j]); + EXPECT_TRUE(a[i] <= a[j]); + EXPECT_TRUE(a[i] == a[j]); + EXPECT_TRUE(a[i] >= a[j]); + EXPECT_FALSE(a[i] > a[j]); + } else if (i < j) { + EXPECT_TRUE(a[i] < a[j]); + EXPECT_TRUE(a[i] <= a[j]); + EXPECT_FALSE(a[i] == a[j]); + EXPECT_FALSE(a[i] >= a[j]); + EXPECT_FALSE(a[i] > a[j]); + } else { + EXPECT_FALSE(a[i] < a[j]); + EXPECT_FALSE(a[i] <= a[j]); + EXPECT_FALSE(a[i] == a[j]); + EXPECT_TRUE(a[i] >= a[j]); + EXPECT_TRUE(a[i] > a[j]); + } + } + } +} + +TEST(AttrKey0SetTest, Works) { + AttrKeyHashSet> set; + + EXPECT_THAT(set, IsEmpty()); + EXPECT_THAT(set, SizeIs(0)); + EXPECT_THAT(set, UnorderedElementsAre()); + EXPECT_FALSE(set.contains(AttrKey())); + EXPECT_TRUE(set.find(AttrKey()) == set.end()); + EXPECT_EQ(set.erase(AttrKey()), 0); + + set.insert(AttrKey()); + + EXPECT_THAT(set, Not(IsEmpty())); + EXPECT_THAT(set, SizeIs(1)); + EXPECT_THAT(set, UnorderedElementsAre(AttrKey())); + EXPECT_TRUE(set.contains(AttrKey())); + EXPECT_TRUE(set.find(AttrKey()) == set.begin()); + EXPECT_EQ(set.erase(AttrKey()), 1); + EXPECT_THAT(set, IsEmpty()); + + set.insert(AttrKey()); + set.clear(); + EXPECT_THAT(set, IsEmpty()); + + set.insert(AttrKey()); + set.erase(AttrKey()); + EXPECT_THAT(set, IsEmpty()); +} + +TEST(AttrKey0MapTest, Works) { + AttrKeyHashMap, int> map; + + EXPECT_THAT(map, IsEmpty()); + EXPECT_THAT(map, SizeIs(0)); + EXPECT_THAT(map, UnorderedElementsAre()); + EXPECT_FALSE(map.contains(AttrKey())); + EXPECT_TRUE(map.find(AttrKey()) == map.end()); + EXPECT_EQ(map.erase(AttrKey()), 0); + + map.try_emplace(AttrKey(), 42); + + EXPECT_THAT(map, Not(IsEmpty())); + EXPECT_THAT(map, SizeIs(1)); + EXPECT_THAT(map, UnorderedElementsAre(Pair(AttrKey(), 42))); + EXPECT_EQ(map.begin()->first, AttrKey()); + EXPECT_EQ(map.begin()->second, 42); + EXPECT_TRUE(map.contains(AttrKey())); + EXPECT_TRUE(map.find(AttrKey()) == map.begin()); + EXPECT_EQ(map.erase(AttrKey()), 1); + EXPECT_THAT(map, IsEmpty()); + + map.insert({AttrKey(), 43}); + map.clear(); + EXPECT_THAT(map, IsEmpty()); + + map.try_emplace(AttrKey(), 43); + map.erase(AttrKey()); + EXPECT_THAT(map, IsEmpty()); + + map.try_emplace(AttrKey(), 43); + map.erase(map.begin()); + EXPECT_THAT(map, IsEmpty()); +} + +TEST(AttrKeyTest, FromRange) { + EXPECT_THAT((AttrKey<0>::FromRange({})), IsOkAndHolds(AttrKey())); + EXPECT_THAT((AttrKey<1>::FromRange({1})), IsOkAndHolds(AttrKey(1))); + EXPECT_THAT((AttrKey<2>::FromRange({1, 2})), IsOkAndHolds(AttrKey(1, 2))); + + EXPECT_THAT((AttrKey<0>::FromRange({1})), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT((AttrKey<1>::FromRange({})), + StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT((AttrKey<2>::FromRange({1})), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(AttrKeyTest, FromRangeSymmetric) { + using Key = AttrKey<3, ElementSymmetry<1, 2>>; + EXPECT_THAT((Key::FromRange({0, 1, 2})), IsOkAndHolds(Key(0, 1, 2))); + EXPECT_THAT((Key::FromRange({0, 2, 1})), IsOkAndHolds(Key(0, 1, 2))); + EXPECT_THAT((Key::FromRange({3, 1, 2})), IsOkAndHolds(Key(3, 1, 2))); + EXPECT_THAT((Key::FromRange({3, 2, 1})), IsOkAndHolds(Key(3, 1, 2))); +} + +TEST(AttrKeyTest, IsAttrKey) { + EXPECT_TRUE(is_attr_key_v>); + EXPECT_TRUE(is_attr_key_v>); + EXPECT_FALSE(is_attr_key_v); +} + +constexpr int kBenchmarkSize = 30; + +template +void BM_HashSet0(benchmark::State& state) { + SetT set; + for (const auto s : state) { + auto it = set.find(AttrKey()); + benchmark::DoNotOptimize(it); + } +} +BENCHMARK(BM_HashSet0>>); +BENCHMARK(BM_HashSet0>>); + +template +void BM_HashMap1(benchmark::State& state) { + absl::flat_hash_map map; + for (int i = 0; i < kBenchmarkSize * kBenchmarkSize; ++i) { + if (i % 2 > 0) { // Half of the lookups are hits. + map[T(i)] = i; + } + } + for (const auto s : state) { + for (int i = 0; i < kBenchmarkSize * kBenchmarkSize; ++i) { + auto it = map.find(T(i)); + benchmark::DoNotOptimize(it); + } + } +} +BENCHMARK(BM_HashMap1>); +BENCHMARK(BM_HashMap1); + +template +void BM_HashMap2(benchmark::State& state) { + absl::flat_hash_map map; + for (int i = 0; i < kBenchmarkSize; ++i) { + for (int j = 0; j < kBenchmarkSize; ++j) { + if ((i * kBenchmarkSize + j) % 2 > 0) { // Half of the lookups are hits. + map[T(i, j)] = i; + } + } + } + for (const auto s : state) { + for (int i = 0; i < kBenchmarkSize; ++i) { + for (int j = 0; j < kBenchmarkSize; ++j) { + auto it = map.find(T(i, j)); + benchmark::DoNotOptimize(it); + } + } + } +} +BENCHMARK(BM_HashMap2>); +BENCHMARK(BM_HashMap2>); + +template +void BM_SortAttrKeys(benchmark::State& state) { + const std::vector> keys = + MakeRandomAttrKeys(state.range(0), state.range(0)); + + for (const auto s : state) { + auto copy = keys; + absl::c_sort(copy); + benchmark::DoNotOptimize(copy); + } +} +BENCHMARK(BM_SortAttrKeys<1>)->Arg(100)->Arg(10000); +BENCHMARK(BM_SortAttrKeys<2>)->Arg(100)->Arg(10000); + +} // namespace +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/elemental/attr_storage.h b/ortools/math_opt/elemental/attr_storage.h new file mode 100644 index 0000000000..5299046567 --- /dev/null +++ b/ortools/math_opt/elemental/attr_storage.h @@ -0,0 +1,429 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "ortools/base/map_util.h" +#include "ortools/math_opt/elemental/attr_key.h" +#include "ortools/math_opt/elemental/symmetry.h" + +namespace operations_research::math_opt { + +namespace detail { + +// A non-default key set based on a vector. This is very efficient for +// insertions, reads, and slicing, but does not support deletions. +template +class DenseKeySet { + public: + // {Dense,Sparse}KeySet stores symmetric keys, symmetry is handled by + // `SlicingStorage`. + using Key = AttrKey; + + DenseKeySet() = default; + + size_t size() const { return key_set_.size(); } + + template + // requires std::invocable + void ForEach(F f) const { + for (const Key& key : key_set_) { + f(key); + } + } + + // Note: this does not check for duplicates. This is fine because inserting + // into this map is gated on inserting into the AttrStorage, which does check + // for duplicates. + void Insert(const Key& key) { key_set_.push_back(key); } + + auto begin() const { return key_set_.begin(); } + auto end() const { return key_set_.end(); } + + private: + std::vector key_set_; +}; + +// A non-default key set based on a hash set. Simple, but requires a hash lookup +// for each insertion and deletion. +template +class SparseKeySet { + public: + // {Dense,Sparse}KeySet stores symmetric keys, symmetry is handled by + // `SlicingStorage`. + using Key = AttrKey; + + explicit SparseKeySet(const DenseKeySet& dense_set) + : key_set_(dense_set.begin(), dense_set.end()) {} + + size_t size() const { return key_set_.size(); } + + template + // requires std::invocable + void ForEach(F f) const { + for (const Key& key : key_set_) { + f(key); + } + } + + void Erase(const Key& key) { key_set_.erase(key); } + void Insert(const Key& key) { key_set_.insert(key); } + + private: + absl::flat_hash_set key_set_; +}; + +// A non-default key set that switches between implementations +// opportunistically: It starts dense, and switches to sparse if there are +// deletions. +template +class KeySet { + public: + using Key = AttrKey; + + size_t size() const { + return std::visit([](const auto& impl) { return impl.size(); }, impl_); + } + + // We can't do begin/end because the iterator types are not the same. + template + // requires std::invocable + void ForEach(F f) const { + return std::visit( + [f = std::move(f)](const auto& impl) { + return impl.ForEach(std::move(f)); + }, + impl_); + } + + auto Erase(const Key& key) { return AsSparse().Erase(key); } + + void Insert(const Key& key) { + std::visit([&](auto& impl) { impl.Insert(key); }, impl_); + } + + private: + SparseKeySet& AsSparse() { + if (auto* sparse = std::get_if>(&impl_)) { + return *sparse; + } + // Switch to a sparse representation. + impl_ = SparseKeySet(std::get>(impl_)); + return std::get>(impl_); + } + + std::variant, SparseKeySet> impl_; +}; + +// When we have two or more dimensions, we need to store the nondefaults for +// each dimension to support slicing. +template +class SlicingSupport { + public: + using Key = AttrKey; + + void AddRowsAndColumns(const Key key) { + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + ForEachDimension([this, key]() { + if (MustInsertNondefault(key, Symmetry{})) { + key_nondefaults_[i][key[i]].Insert(key.template RemoveElement()); + } + }); + } + + // Requires key is currently stored with a non-default value. + void ClearRowsAndColumns(Key key) { + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + ForEachDimension([this, key]() { + const auto& key_elem = key[i]; + auto& nondefaults = key_nondefaults_[i]; + if (nondefaults[key_elem].size() == 1) { + nondefaults.erase(key_elem); + } else { + nondefaults[key_elem].Erase(key.template RemoveElement()); + } + }); + } + + void Clear() { + for (auto& key_nondefaults : key_nondefaults_) { + key_nondefaults.clear(); + } + } + + template + std::vector Slice(const int64_t key_elem) const { + return SliceImpl( + key_elem, Symmetry{}, + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + [key_elem](KeySetExpansion... expansions) { + std::vector slice((expansions.key_set.size() + ...)); + Key* out = slice.data(); + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + const auto append = [key_elem, &out]( + const KeySetExpansion& expansion) { + expansion.key_set.ForEach( + [key_elem, &out](const AttrKey other_elems) { + *out = other_elems.template AddElement(key_elem); + ++out; + }); + }; + (append(expansions), ...); + return slice; + }); + } + + template + int64_t GetSliceSize(const int64_t key_elem) const { + return SliceImpl(key_elem, Symmetry{}, [](const auto... expansions) { + return (expansions.key_set.size() + ...); + }); + } + + private: + // We store the nondefaults for a given id along a given dimension as a set of + // `AttrKey` (the current dimension is not stored). + using NonDefaultKeySet = KeySet; + // For each dimension, we store the nondefaults for each id. + using NonDefaultKeySetById = absl::flat_hash_map; + // We need one such set per dimension. + using NonDefaultsPerDimension = std::array; + + // Represents a NonDefaultKeySet to be expanded by inserting an element on + // dimension `i`. + template + struct KeySetExpansion { + KeySetExpansion(const NonDefaultsPerDimension& key_nondefaults, + int64_t key_elem) + : key_set(gtl::FindWithDefault(key_nondefaults[i], key_elem)) {} + const NonDefaultKeySet& key_set; + }; + + // A helper function that calls `F` `i` times with template arguments `n-1` to + // `0`. + template + static void ForEachDimension(const F& f) { + f.template operator()(); + if constexpr (i > 0) { + ForEachDimension(f); + } + } + + template + static bool MustInsertNondefault(const Key&, NoSymmetry) { + return true; + } + + template + static bool MustInsertNondefault(const Key& key, ElementSymmetry) { + // For attributes that are symmetric on `k` and `l`, elements on the + // diagonal need to be in only one of the nondefaults for `k` or `l` + // (otherwise they would be counted twice in `Slice()`). We arbitrarily pick + // `k`. + if constexpr (i == l) { + const bool is_diagonal = key[k] == key[l]; + return !is_diagonal; + } + return true; + } + + // `Fn` should be a functor that takes any number of `KeySetExpansion` + // arguments. + template + auto SliceImpl(const int64_t key_elem, NoSymmetry, const Fn& fn) const { + static_assert(n > 1); + return fn(KeySetExpansion(key_nondefaults_, key_elem)); + } + + template + auto SliceImpl(const int64_t key_elem, ElementSymmetry, + const Fn& fn) const { + static_assert(n > 1); + if constexpr (i != k && i != l) { + // This is a normal dimension, not a symmetric one. + return SliceImpl(key_elem, NoSymmetry(), fn); + } else { + // For symmetric dimensions, we need to look up the keys on both + // dimensions `l` and `k`. + return fn(KeySetExpansion(key_nondefaults_, key_elem), + KeySetExpansion(key_nondefaults_, key_elem)); + } + } + + NonDefaultsPerDimension key_nondefaults_; +}; + +// Without slicing we don't need to track anything. +template +struct SlicingSupport>> { + using Key = AttrKey; + + void AddRowsAndColumns(Key) {} + void ClearRowsAndColumns(Key) {} + void Clear() {} +}; + +} // namespace detail + +// Stores the value of an attribute keyed on n elements (e.g. +// linear_constraint_coefficient is a double valued attribute keyed first on +// LinearConstraint and then on Variable). +// +// Memory usage: +// Storing `k` elements with non-default values in a `AttrStorage` uses +// `sizeof(V) * (n^2 + 1) * k / load_factor` (where load_factor is the absl +// hash map load factor, typically 0.8), plus a small allocation overhead of +// `O(k)`. +template +class AttrStorage { + public: + using Key = AttrKey; + // If this no longer holds, we should sprinkle the code with `move`s and + // return `V`s by ref. + static_assert(std::is_trivially_copyable_v); + + // Generally avoid, provided to make working with std::array easier. + explicit AttrStorage() : AttrStorage({}) {} + + // The default value of the attribute is its value when the model is created + // (e.g. for linear_constraint_coefficient, 0.0). + explicit AttrStorage(const V default_value) : default_value_(default_value) {} + + AttrStorage(const AttrStorage&) = default; + AttrStorage& operator=(const AttrStorage&) = default; + AttrStorage(AttrStorage&&) = default; + AttrStorage& operator=(AttrStorage&&) = default; + + // Returns true if the attribute for `key` has a value different from its + // default. + bool IsNonDefault(const Key key) const { + return non_default_values_.contains(key); + } + + // Returns the previous value if value has changed, otherwise returns + // `std::nullopt`. + std::optional Set(const Key key, const V value) { + bool is_default = value == default_value_; + if (is_default) { + const auto it = non_default_values_.find(key); + if (it == non_default_values_.end()) { + return std::nullopt; + } + const V prev_value = it->second; + non_default_values_.erase(it); + slicing_support_.ClearRowsAndColumns(key); + return prev_value; + } + const auto [it, inserted] = non_default_values_.try_emplace(key, value); + if (inserted) { + slicing_support_.AddRowsAndColumns(key); + return default_value_; + } + // !is_default and !inserted + if (value == it->second) { + return std::nullopt; + } + return std::exchange(it->second, value); + } + + // Returns the value of the attribute for `key` (return the default value if + // the attribute value for `key` is unset). + V Get(const Key key) const { + return GetIfNonDefault(key).value_or(default_value_); + } + + // Returns the value of the attribute for `key`, or nullopt. + std::optional GetIfNonDefault(const Key key) const { + auto it = non_default_values_.find(key); + if (it == non_default_values_.end()) { + return std::nullopt; + } + return it->second; + } + + // Sets the value of the attribute for `key` to the default value. + void Erase(const Key key) { + if (non_default_values_.erase(key)) { + slicing_support_.ClearRowsAndColumns(key); + } + } + + // Returns the keys (ids pairs) the of the elements with a non-default value + // for this attribute. + std::vector NonDefaults() const { + std::vector result; + result.reserve(non_default_values_.size()); + for (const auto& [key, unused] : non_default_values_) { + result.push_back(key); + } + return result; + } + + // Returns the set of all keys `K` such that: + // - There exists `k_{0}..k_{n-1}` such that + // `K == AttrKey(k_{0}, ..., k_{i-1}, key_elem, k_{i+1}, ..., k_{n-1})`, and + // - `K` has a non-default value for this attribute. + template + std::vector Slice(const int64_t key_elem) const { + static_assert(n >= 1); + if constexpr (n == 1) { + return non_default_values_.contains(Key(key_elem)) + ? std::vector({Key(key_elem)}) + : std::vector(); + } else { + return slicing_support_.template Slice(key_elem); + } + } + + // Returns the size of the given slice: This is equivalent to + // `Slice(key_elem).size()`, but `O(1)`. + template + int64_t GetSliceSize(const int64_t key_elem) const { + static_assert(n >= 1); + if constexpr (n == 1) { + return non_default_values_.count(Key(key_elem)); + } else { + return slicing_support_.template GetSliceSize(key_elem); + } + } + + // Returns the number of keys (element pairs) with non-default values for this + // attribute. + int64_t num_non_defaults() const { return non_default_values_.size(); } + + // Restore all elements to their default value for this attribute. + void Clear() { + non_default_values_.clear(); + slicing_support_.Clear(); + } + + private: + V default_value_; + AttrKeyHashMap non_default_values_; + detail::SlicingSupport slicing_support_; +}; + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_ diff --git a/ortools/math_opt/elemental/attr_storage_test.cc b/ortools/math_opt/elemental/attr_storage_test.cc new file mode 100644 index 0000000000..875cfedb02 --- /dev/null +++ b/ortools/math_opt/elemental/attr_storage_test.cc @@ -0,0 +1,574 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/attr_storage.h" + +#include +#include +#include + +#include "benchmark/benchmark.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/math_opt/elemental/attr_key.h" +#include "ortools/math_opt/elemental/symmetry.h" + +namespace operations_research::math_opt { +namespace { + +using ::testing::IsEmpty; +using ::testing::Optional; +using ::testing::UnorderedElementsAre; + +TEST(Attr0StorageTest, EmptyGetters) { + const AttrStorage attr_storage(1.0); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0); + EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey())); +} + +TEST(Attr0StorageTest, SetDefaultToDefault) { + AttrStorage attr_storage(1.0); + + EXPECT_FALSE(attr_storage.Set(AttrKey(), 1.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0); + EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey())); +} + +TEST(Attr0StorageTest, SetDefaultToNonDefault) { + AttrStorage attr_storage(1.0); + + EXPECT_TRUE(attr_storage.Set(AttrKey(), 10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 10.0); + EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey())); +} + +TEST(Attr0StorageTest, SetNonDefaultToDefault) { + AttrStorage attr_storage(1.0); + EXPECT_TRUE(attr_storage.Set(AttrKey(), 10.0)); + + EXPECT_TRUE(attr_storage.Set(AttrKey(), 1.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0); + EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey())); +} + +TEST(Attr0StorageTest, SetNonDefaultToNonDefaultDifferent) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(), 10.0); + + EXPECT_TRUE(attr_storage.Set(AttrKey(), 20.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 20.0); + EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey())); +} + +TEST(Attr0StorageTest, SetNonDefaultToNonDefaultSame) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(), 10.0); + + EXPECT_FALSE(attr_storage.Set(AttrKey(), 10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 10.0); + EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey())); +} + +//////////////////////////////////////////////////////////////////////////////// +// Attr1Storage +//////////////////////////////////////////////////////////////////////////////// + +TEST(Attr1StorageTest, EmptyGetters) { + const AttrStorage attr_storage(1.0); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(0)), 1.0); + EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey(0))); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); + EXPECT_EQ(attr_storage.num_non_defaults(), 0); + EXPECT_THAT(attr_storage.Slice<0>(0), IsEmpty()); +} + +TEST(Attr1StorageTest, GettersNonEmpty) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2), 10.0); + attr_storage.Set(AttrKey(3), 11.0); + attr_storage.Set(AttrKey(5), 12.0); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 11.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(4)), 1.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5)), 12.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(6)), 1.0); + + EXPECT_THAT(attr_storage.NonDefaults(), + UnorderedElementsAre(AttrKey(2), AttrKey(3), AttrKey(5))); + EXPECT_EQ(attr_storage.num_non_defaults(), 3); + EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(AttrKey(3))); +} + +TEST(Attr1StorageTest, SetDefaultToDefault) { + AttrStorage attr_storage(1.0); + + EXPECT_FALSE(attr_storage.Set(AttrKey(2), 1.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); +} + +TEST(Attr1StorageTest, SetDefaultToNonDefault) { + AttrStorage attr_storage(1.0); + + EXPECT_TRUE(attr_storage.Set(AttrKey(2), 10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2))); +} + +TEST(Attr1StorageTest, SetNonDefaultToDefault) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2), 10.0); + + EXPECT_TRUE(attr_storage.Set(AttrKey(2), 1.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); +} + +TEST(Attr1StorageTest, SetNonDefaultToNonDefaultDifferent) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2), 5.0); + + EXPECT_TRUE(attr_storage.Set(AttrKey(2), 10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2))); +} + +TEST(Attr1StorageTest, SetNonDefaultToNonDefaultSame) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2), 10.0); + + EXPECT_FALSE(attr_storage.Set(AttrKey(2), 10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2))); +} + +TEST(Attr1StorageTest, Clear) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2), 10.0); + attr_storage.Set(AttrKey(3), 11.0); + + EXPECT_THAT(attr_storage.NonDefaults(), + UnorderedElementsAre(AttrKey(2), AttrKey(3))); + EXPECT_EQ(attr_storage.num_non_defaults(), 2); + + attr_storage.Clear(); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 1.0); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); + EXPECT_EQ(attr_storage.num_non_defaults(), 0); +} + +TEST(Attr1StorageTest, Erase) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2), 10.0); + attr_storage.Set(AttrKey(3), 11.0); + + attr_storage.Erase(AttrKey(2)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 11.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3))); + EXPECT_EQ(attr_storage.num_non_defaults(), 1); +} + +//////////////////////////////////////////////////////////////////////////////// +// Attr2Storage +//////////////////////////////////////////////////////////////////////////////// + +TEST(Attr2StorageTest, EmptyGetters) { + const AttrStorage attr_storage(1.0); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(0, 0)), 1.0); + EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey(0, 0))); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); + EXPECT_EQ(attr_storage.num_non_defaults(), 0); + EXPECT_THAT(attr_storage.Slice<1>(0), IsEmpty()); + EXPECT_THAT(attr_storage.GetSliceSize<1>(0), 0); + EXPECT_THAT(attr_storage.Slice<0>(0), IsEmpty()); + EXPECT_THAT(attr_storage.GetSliceSize<0>(0), 0); +} + +TEST(Attr2StorageTest, GettersNonEmpty) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 10.0); + attr_storage.Set(AttrKey(2, 5), 11.0); + attr_storage.Set(AttrKey(5, 5), 12.0); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 5)), 11.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5, 5)), 12.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5, 2)), 1.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 2)), 1.0); + + EXPECT_THAT( + attr_storage.NonDefaults(), + UnorderedElementsAre(AttrKey(2, 3), AttrKey(2, 5), AttrKey(5, 5))); + EXPECT_EQ(attr_storage.num_non_defaults(), 3); + EXPECT_THAT(attr_storage.Slice<0>(2), + UnorderedElementsAre(AttrKey(2, 3), AttrKey(2, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<0>(2), 2); + EXPECT_THAT(attr_storage.Slice<0>(3), IsEmpty()); + EXPECT_THAT(attr_storage.GetSliceSize<0>(3), 0); + EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<0>(5), 1); + + EXPECT_THAT(attr_storage.Slice<1>(2), IsEmpty()); + EXPECT_THAT(attr_storage.GetSliceSize<1>(2), 0); + EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(AttrKey(2, 3))); + EXPECT_THAT(attr_storage.GetSliceSize<1>(3), 1); + EXPECT_THAT(attr_storage.Slice<1>(5), + UnorderedElementsAre(AttrKey(2, 5), AttrKey(5, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<1>(5), 2); +} + +TEST(Attr2StorageTest, GettersNonEmptySymmetric) { + // Dim 0 + // | 0 1 2 3 4 5 + // --+------------------------ + // 0 | 0 + // D 1 | 0 0 + // i 2 | 0 0 0 + // m 3 | 0 10 0 0 + // 1 4 | 0 0 0 0 0 + // 5 | 0 11 0 0 0 12 + // + using Storage = AttrStorage>; + using Key = Storage::Key; + Storage attr_storage(1.0); + attr_storage.Set(Key(2, 3), 10.0); + attr_storage.Set(Key(2, 5), 123.0); + attr_storage.Set(Key(5, 2), 11.0); // Overwrites 123.0. + attr_storage.Set(Key(5, 5), 12.0); + + EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 3)), 10.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 5)), 11.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(Key(5, 5)), 12.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(Key(3, 2)), 10.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(Key(5, 2)), 11.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 2)), 1.0); + + EXPECT_THAT(attr_storage.NonDefaults(), + UnorderedElementsAre(Key(2, 3), Key(2, 5), Key(5, 5))); + EXPECT_EQ(attr_storage.num_non_defaults(), 3); + EXPECT_THAT(attr_storage.Slice<0>(2), + UnorderedElementsAre(Key(2, 3), Key(2, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<0>(2), 2); + EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(Key(2, 3))); + EXPECT_THAT(attr_storage.GetSliceSize<0>(3), 1); + EXPECT_THAT(attr_storage.Slice<0>(4), IsEmpty()); + EXPECT_THAT(attr_storage.GetSliceSize<0>(4), 0); + EXPECT_THAT(attr_storage.Slice<0>(5), + UnorderedElementsAre(Key(2, 5), Key(5, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<0>(5), 2); + + EXPECT_THAT(attr_storage.Slice<1>(2), + UnorderedElementsAre(Key(2, 3), Key(2, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<1>(2), 2); + EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(Key(2, 3))); + EXPECT_THAT(attr_storage.GetSliceSize<1>(3), 1); + EXPECT_THAT(attr_storage.Slice<1>(4), IsEmpty()); + EXPECT_THAT(attr_storage.GetSliceSize<1>(4), 0); + EXPECT_THAT(attr_storage.Slice<1>(5), + UnorderedElementsAre(Key(2, 5), Key(5, 5))); + EXPECT_THAT(attr_storage.GetSliceSize<1>(5), 2); +} + +TEST(Attr2StorageTest, SetDefaultToDefault) { + AttrStorage attr_storage(1.0); + + EXPECT_FALSE(attr_storage.Set(AttrKey(2, 3), 1.0).has_value()); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); +} + +TEST(Attr2StorageTest, SetDefaultToNonDefault) { + AttrStorage attr_storage(1.0); + + EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 10.0), Optional(1.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3))); +} + +TEST(Attr2StorageTest, SetNonDefaultToDefault) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 10.0); + + EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 1.0), Optional(10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); +} + +TEST(Attr2StorageTest, SetNonDefaultToNonDefaultDifferent) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 5.0); + + EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 10.0), Optional(5.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3))); +} + +TEST(Attr2StorageTest, SetNonDefaultToNonDefaultSame) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 10.0); + + EXPECT_FALSE(attr_storage.Set(AttrKey(2, 3), 10.0)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3))); +} + +TEST(Attr2StorageTest, Clear) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 10.0); + attr_storage.Set(AttrKey(3, 4), 11.0); + + EXPECT_THAT(attr_storage.NonDefaults(), + UnorderedElementsAre(AttrKey(2, 3), AttrKey(3, 4))); + EXPECT_EQ(attr_storage.num_non_defaults(), 2); + + attr_storage.Clear(); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3, 4)), 1.0); + EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty()); + EXPECT_EQ(attr_storage.num_non_defaults(), 0); +} + +TEST(Attr2StorageTest, Erase) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 10.0); + attr_storage.Set(AttrKey(3, 4), 11.0); + + attr_storage.Erase(AttrKey(2, 3)); + + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0); + EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3, 4)), 11.0); + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3, 4))); + EXPECT_EQ(attr_storage.num_non_defaults(), 1); +} + +TEST(Attr2StorageTest, EraseColumnLives) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(2, 3), 10.0); + attr_storage.Set(AttrKey(5, 3), 11.0); + + attr_storage.Erase(AttrKey(2, 3)); + + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(5, 3))); + EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 3))); + EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(AttrKey(5, 3))); + + // Insert again. + attr_storage.Set(AttrKey(2, 3), 12.0); + EXPECT_THAT(attr_storage.NonDefaults(), + UnorderedElementsAre(AttrKey(2, 3), AttrKey(5, 3))); + EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 3))); + EXPECT_THAT(attr_storage.Slice<1>(3), + UnorderedElementsAre(AttrKey(2, 3), AttrKey(5, 3))); +} + +TEST(Attr2StorageTest, EraseRowLives) { + AttrStorage attr_storage(1.0); + attr_storage.Set(AttrKey(3, 2), 10.0); + attr_storage.Set(AttrKey(3, 5), 11.0); + + attr_storage.Erase(AttrKey(3, 2)); + + EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3, 5))); + EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(AttrKey(3, 5))); + EXPECT_THAT(attr_storage.Slice<1>(5), UnorderedElementsAre(AttrKey(3, 5))); +} + +// Makes a set of `n` 1-dimensional keys. +std::vector> Make1DKeys(int n) { + std::vector> keys; + for (int64_t i = 0; i < n; ++i) { + keys.emplace_back(i); + } + return keys; +} + +// Makes a set of `n^2` 2-dimensional keys. +// NOTE: depending in `Symmetry` this might create duplicate keys. This is +// intentional, as we want to have the same number of keys to be able to compare +// the performance of different symmetries. +template +std::vector> Make2DKeys(int n) { + std::vector> keys; + for (int64_t i = 0; i < n; ++i) { + for (int64_t j = 0; j < n; ++j) { + keys.emplace_back(i, j); + } + } + return keys; +} + +// A functor that returns true every N calls, false otherwise. +template +struct TrueEvery { + int n = 0; + bool operator()() { + if (n == N) { + n = 0; + return true; + } + ++n; + return false; + } +}; + +void BM_Attr0StorageSet(benchmark::State& state) { + AttrStorage attr_storage(1.0); + + for (const auto s : state) { + attr_storage.Set(AttrKey(), 10.0); + benchmark::DoNotOptimize(attr_storage); + } +} +BENCHMARK(BM_Attr0StorageSet); + +void BM_Attr1StorageSet(benchmark::State& state) { + const int n = state.range(0); + + AttrStorage attr_storage(1.0); + const auto keys = Make1DKeys(n); + + for (const auto s : state) { + for (const auto& key : keys) { + attr_storage.Set(key, 10.0); + } + } +} +BENCHMARK(BM_Attr1StorageSet)->Arg(900); + +template +void BM_Attr2StorageSet(benchmark::State& state) { + const int n = state.range(0); + + const auto keys = Make2DKeys(n); + + std::optional> attr_storage(1.0); + for (const auto s : state) { + for (const auto& key : keys) { + attr_storage->Set(key, 10.0); + } + state.PauseTiming(); + attr_storage.emplace(1.0); + state.ResumeTiming(); + } +} +BENCHMARK(BM_Attr2StorageSet)->Arg(30); +BENCHMARK(BM_Attr2StorageSet>)->Arg(30); + +void BM_Attr0StorageGet(benchmark::State& state) { + AttrStorage attr_storage(1.0); + + for (const auto s : state) { + double v = attr_storage.Get(AttrKey()); + benchmark::DoNotOptimize(v); + } +} +BENCHMARK(BM_Attr0StorageGet); + +void BM_Attr1StorageGet(benchmark::State& state) { + const int n = state.range(0); + + AttrStorage attr_storage(1.0); + const auto keys = Make1DKeys(n); + // Insert half the keys. + TrueEvery<2> sample; + for (const auto& key : keys) { + if (sample()) { + attr_storage.Set(key, 10.0); + } + } + + for (const auto s : state) { + for (const auto& key : keys) { + double v = attr_storage.Get(key); + benchmark::DoNotOptimize(v); + } + } +} +BENCHMARK(BM_Attr1StorageGet)->Arg(900); + +template +void BM_Attr2StorageGet(benchmark::State& state) { + const int n = state.range(0); + + AttrStorage attr_storage(1.0); + const auto keys = Make2DKeys(n); + // Insert half the keys. + TrueEvery<2> sample; + for (const auto& key : keys) { + if (sample()) { + attr_storage.Set(key, 10.0); + } + } + + for (const auto s : state) { + for (const auto& key : keys) { + double v = attr_storage.Get(key); + benchmark::DoNotOptimize(v); + } + } +} +BENCHMARK(BM_Attr2StorageGet)->Arg(30); +BENCHMARK(BM_Attr2StorageGet>)->Arg(30); + +template +void BM_Attr2StorageSlice(benchmark::State& state) { + const int n = state.range(0); + + AttrStorage attr_storage(1.0); + const auto keys = Make2DKeys(n); + // Insert 5% of the keys. + TrueEvery<20> sample; + for (const auto& key : keys) { + if (sample()) { + attr_storage.Set(key, 10.0); + } + } + + for (const auto s : state) { + for (int key_id = 0; key_id < n; ++key_id) { + auto slice0 = attr_storage.template Slice<0>(key_id); + auto slice1 = attr_storage.template Slice<1>(key_id); + benchmark::DoNotOptimize(slice0); + benchmark::DoNotOptimize(slice1); + } + } +} +BENCHMARK(BM_Attr2StorageSlice)->Arg(30); +BENCHMARK(BM_Attr2StorageSlice>)->Arg(30); + +} // namespace +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/elemental/attributes.h b/ortools/math_opt/elemental/attributes.h new file mode 100644 index 0000000000..7c628305a4 --- /dev/null +++ b/ortools/math_opt/elemental/attributes.h @@ -0,0 +1,349 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "ortools/base/array.h" +#include "ortools/math_opt/elemental/arrays.h" +#include "ortools/math_opt/elemental/elements.h" +#include "ortools/math_opt/elemental/symmetry.h" + +namespace operations_research::math_opt { + +// A base class for all attribute type descriptors. +// `ValueTypeT` is the attribute value type, and `n` is the number of key +// elements (e.g. `Double2` attribute has `ValueType` == `double` and `n` == 2). +// This uses +// [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) in +// `Impl` to deduce common descriptor properties from `Impl`. `Impl` must +// inherit from `AttrTypeDescriptor` and define the following entities: +// - `static constexpr absl::string_view kName`: The name of the attribute +// type. +// - `enum class AttrType`: The attribute type, with `k` enumerators +// corresponding to attributes for this type. Enumerators must be numbered +// `0..(k-1)` (a good way to do this is to leave them unnumbered). +// - `std::array kAttrDescriptors`: A descriptor for each +// of the `k` attributes for this type. +template +struct AttrTypeDescriptor { + // The type of attribute values (e.g. `bool`, `int64_t`, `double`). + using ValueType = ValueTypeT; + + // The number of key elements. + static constexpr int kNumKeyElements = n; + + // The key symmetry. For example, this can be used to enforce that + // quadratic objective coefficients are the same for `(i, j)` and `(j, i)` + // (see `kObjQuadCoef` below). + using Symmetry = SymmetryT; + + // A descriptor of an attribute of this attribute type. + // E.g., this could describe the attribute `DoubleAttr1::kVarLb`. + struct AttrDescriptor { + // The name of the attribute value. + absl::string_view name; + // The default value. + ValueType default_value; + // The types of the `n` key elements. + std::array key_types; + }; + + // Returns the number of attributes of this attribute type. + static constexpr int NumAttrs() { return Impl::kAttrDescriptors.size(); } + + // Returns an array with all attributes of this attribute type. + static constexpr auto Enumerate() { + std::array result; + for (int i = 0; i < NumAttrs(); ++i) { + result[i] = {static_cast(i)}; + } + return result; + } +}; + +struct BoolAttr0TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "BoolAttr0"; + + enum class AttrType { kMaximize }; + + static constexpr auto kAttrDescriptors = gtl::to_array( + {{.name = "maximize", .default_value = false, .key_types = {}}}); +}; + +struct BoolAttr1TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "BoolAttr1"; + + enum class AttrType { + kVarInteger, + kAuxObjMaximize, + kIndConActivateOnZero, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array( + {{.name = "variable_integer", + .default_value = false, + .key_types = {ElementType::kVariable}}, + {.name = "auxiliary_objective_maximize", + .default_value = false, + .key_types = {ElementType::kAuxiliaryObjective}}, + {.name = "indicator_constraint_activate_on_zero", + .default_value = false, + .key_types = {ElementType::kIndicatorConstraint}}}); +}; + +struct IntAttr0TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "IntAttr0"; + + enum class AttrType { + kObjPriority, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "objective_priority", .default_value = 0, .key_types = {}}, + }); +}; + +struct IntAttr1TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "IntAttr1"; + + enum class AttrType { + kAuxObjPriority, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "auxiliary_objective_priority", + .default_value = 0, + .key_types = {ElementType::kAuxiliaryObjective}}, + }); +}; + +struct DoubleAttr0TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "DoubleAttr0"; + + enum class AttrType { kObjOffset }; + + static constexpr auto kAttrDescriptors = gtl::to_array( + {{.name = "objective_offset", .default_value = 0.0, .key_types = {}}}); +}; + +struct DoubleAttr1TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "DoubleAttr1"; + + enum class AttrType { + kVarLb, + kVarUb, + kObjLinCoef, + kLinConLb, + kLinConUb, + kAuxObjOffset, + kQuadConLb, + kQuadConUb, + kIndConLb, + kIndConUb, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "variable_lower_bound", + .default_value = -std::numeric_limits::infinity(), + .key_types = {ElementType::kVariable}}, + {.name = "variable_upper_bound", + .default_value = std::numeric_limits::infinity(), + .key_types = {ElementType::kVariable}}, + {.name = "objective_linear_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kVariable}}, + {.name = "linear_constraint_lower_bound", + .default_value = -std::numeric_limits::infinity(), + .key_types = {ElementType::kLinearConstraint}}, + {.name = "linear_constraint_upper_bound", + .default_value = std::numeric_limits::infinity(), + .key_types = {ElementType::kLinearConstraint}}, + {.name = "auxiliary_objective_offset", + .default_value = 0.0, + .key_types = {ElementType::kAuxiliaryObjective}}, + {.name = "quadratic_constraint_lower_bound", + .default_value = -std::numeric_limits::infinity(), + .key_types = {ElementType::kQuadraticConstraint}}, + {.name = "quadratic_constraint_upper_bound", + .default_value = std::numeric_limits::infinity(), + .key_types = {ElementType::kQuadraticConstraint}}, + {.name = "indicator_constraint_lower_bound", + .default_value = -std::numeric_limits::infinity(), + .key_types = {ElementType::kIndicatorConstraint}}, + {.name = "indicator_constraint_upper_bound", + .default_value = std::numeric_limits::infinity(), + .key_types = {ElementType::kIndicatorConstraint}}, + }); +}; + +struct DoubleAttr2TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "DoubleAttr2"; + + enum class AttrType { + kLinConCoef, + kAuxObjLinCoef, + kQuadConLinCoef, + kIndConLinCoef + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "linear_constraint_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kLinearConstraint, ElementType::kVariable}}, + {.name = "auxiliary_objective_linear_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kAuxiliaryObjective, ElementType::kVariable}}, + {.name = "quadratic_constraint_linear_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kQuadraticConstraint, + ElementType::kVariable}}, + {.name = "indicator_constraint_linear_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kIndicatorConstraint, + ElementType::kVariable}}, + }); +}; + +struct SymmetricDoubleAttr2TypeDescriptor + : public AttrTypeDescriptor, + SymmetricDoubleAttr2TypeDescriptor> { + static constexpr absl::string_view kName = "SymmetricDoubleAttr2"; + + enum class AttrType { + kObjQuadCoef, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "objective_quadratic_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kVariable, ElementType::kVariable}}, + }); +}; + +// Note: For this type, we pick the symmetric elements to be the last 2 elements +// of the key (index 1 and 2). +struct SymmetricDoubleAttr3TypeDescriptor + : public AttrTypeDescriptor, + SymmetricDoubleAttr3TypeDescriptor> { + static constexpr absl::string_view kName = "SymmetricDoubleAttr3"; + + enum class AttrType { + kQuadConQuadCoef, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "quadratic_constraint_quadratic_coefficient", + .default_value = 0.0, + .key_types = {ElementType::kQuadraticConstraint, ElementType::kVariable, + ElementType::kVariable}}, + }); +}; + +struct VariableAttr1TypeDescriptor + : public AttrTypeDescriptor { + static constexpr absl::string_view kName = "VariableAttr1"; + + enum class AttrType { + kIndConIndicator, + }; + + static constexpr auto kAttrDescriptors = gtl::to_array({ + {.name = "indicator_constraint_indicator", + .default_value = VariableId(), + .key_types = {ElementType::kIndicatorConstraint}}, + }); +}; + +// The list of all available attribute descriptors. This is typically +// manipulated using the `AllAttrs` helper in `derived_data.h`. +using AllAttrTypeDescriptors = + std::tuple; + +// Aliases for types. +using BoolAttr0 = BoolAttr0TypeDescriptor::AttrType; +using BoolAttr1 = BoolAttr1TypeDescriptor::AttrType; +using IntAttr0 = IntAttr0TypeDescriptor::AttrType; +using IntAttr1 = IntAttr1TypeDescriptor::AttrType; +using DoubleAttr0 = DoubleAttr0TypeDescriptor::AttrType; +using DoubleAttr1 = DoubleAttr1TypeDescriptor::AttrType; +using DoubleAttr2 = DoubleAttr2TypeDescriptor::AttrType; +using SymmetricDoubleAttr2 = SymmetricDoubleAttr2TypeDescriptor::AttrType; +using SymmetricDoubleAttr3 = SymmetricDoubleAttr3TypeDescriptor::AttrType; +using VariableAttr1 = VariableAttr1TypeDescriptor::AttrType; + +// Returns the index of `AttrT` in `AllAttrTypes` if `AttrT` is an attribute +// type, -1 otherwise. +template +static constexpr int GetIndexIfAttr() { + using Tuple = AllAttrTypeDescriptors; + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + return ApplyOnIndexRange>([]() { + return ((std::is_same_v>, + typename std::tuple_element_t::AttrType> + ? (i + 1) + : 0) + + ... + -1); + }); +} + +template () >= 0)>> +absl::string_view ToString(const AttrT attr) { + using Descriptor = + std::tuple_element_t(), AllAttrTypeDescriptors>; + const int attr_index = static_cast(attr); + return Descriptor::kAttrDescriptors[attr_index].name; +} + +template () >= 0)>> +void AbslStringify(Sink& sink, const AttrT attr_type) { + sink.Append(ToString(attr_type)); +} + +template +std::enable_if_t<(GetIndexIfAttr() >= 0), std::ostream&> operator<<( + std::ostream& ostr, AttrT attr) { + ostr << ToString(attr); + return ostr; +} + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_ diff --git a/ortools/math_opt/elemental/attributes_test.cc b/ortools/math_opt/elemental/attributes_test.cc new file mode 100644 index 0000000000..4e82f856df --- /dev/null +++ b/ortools/math_opt/elemental/attributes_test.cc @@ -0,0 +1,58 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/attributes.h" + +#include "absl/strings/str_cat.h" +#include "gtest/gtest.h" +#include "ortools/math_opt/elemental/arrays.h" +#include "ortools/math_opt/testing/stream.h" + +namespace operations_research::math_opt { +namespace { + +TEST(ToStringTests, EachTypeCanConvert) { + EXPECT_EQ(ToString(BoolAttr0::kMaximize), "maximize"); + EXPECT_EQ(ToString(BoolAttr1::kVarInteger), "variable_integer"); + EXPECT_EQ(ToString(IntAttr0::kObjPriority), "objective_priority"); + EXPECT_EQ(ToString(IntAttr1::kAuxObjPriority), + "auxiliary_objective_priority"); + EXPECT_EQ(ToString(DoubleAttr0::kObjOffset), "objective_offset"); + EXPECT_EQ(ToString(DoubleAttr1::kVarLb), "variable_lower_bound"); + EXPECT_EQ(ToString(DoubleAttr2::kLinConCoef), + "linear_constraint_coefficient"); + EXPECT_EQ(ToString(SymmetricDoubleAttr2::kObjQuadCoef), + "objective_quadratic_coefficient"); + EXPECT_EQ(ToString(SymmetricDoubleAttr3::kQuadConQuadCoef), + "quadratic_constraint_quadratic_coefficient"); + // Now check that absl::Stringify wraps ToString() + EXPECT_EQ(absl::StrCat(BoolAttr0::kMaximize), "maximize"); + // Now check that << wraps ToString() + EXPECT_EQ(StreamToString(BoolAttr0::kMaximize), "maximize"); +} + +// Validate that for all symmetric attribute types, the symmetry is consistent +// with element types. +TEST(SymmetryTest, AllSymmetricTypesAreCorrect) { + ForEach( + // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) + [](const Descriptor&) { + for (const auto& attr : Descriptor::kAttrDescriptors) { + Descriptor::Symmetry::CheckElementTypes(attr.key_types); + } + }, + AllAttrTypeDescriptors{}); +} + +} // namespace +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/elemental/codegen/BUILD.bazel b/ortools/math_opt/elemental/codegen/BUILD.bazel new file mode 100644 index 0000000000..bf6d9c59b4 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/BUILD.bazel @@ -0,0 +1,81 @@ +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_binary.bzl", "cc_binary") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +cc_library( + name = "gen", + srcs = ["gen.cc"], + hdrs = ["gen.h"], + deps = [ + "//ortools/math_opt/elemental:arrays", + "//ortools/math_opt/elemental:attributes", + "//ortools/math_opt/elemental:elements", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:string_view", + "@abseil-cpp//absl/types:span", + ], +) + +cc_library( + name = "gen_c", + srcs = ["gen_c.cc"], + hdrs = ["gen_c.h"], + deps = [ + ":gen", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/types:span", + ], +) + +cc_library( + name = "testing", + testonly = 1, + hdrs = ["testing.h"], + deps = [":gen"], +) + +cc_library( + name = "gen_python", + srcs = ["gen_python.cc"], + hdrs = ["gen_python.h"], + deps = [ + ":gen", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/strings:string_view", + "@abseil-cpp//absl/types:span", + ], +) + +cc_binary( + name = "codegen", + srcs = ["codegen.cc"], + visibility = [ + "//ortools/math_opt/elemental/c_api:__subpackages__", + "//ortools/math_opt/elemental/python:__subpackages__", + ], + deps = [ + ":gen", + ":gen_c", + ":gen_python", + "//ortools/base", + "@abseil-cpp//absl/flags:flag", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings:string_view", + ], +) diff --git a/ortools/math_opt/elemental/codegen/codegen.cc b/ortools/math_opt/elemental/codegen/codegen.cc new file mode 100644 index 0000000000..e0aeacabf3 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/codegen.cc @@ -0,0 +1,52 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "ortools/base/init_google.h" +#include "ortools/math_opt/elemental/codegen/gen.h" +#include "ortools/math_opt/elemental/codegen/gen_c.h" +#include "ortools/math_opt/elemental/codegen/gen_python.h" + +ABSL_FLAG(std::string, binding_type, "", "The binding type to generate."); + +namespace operations_research::math_opt::codegen { + +namespace { +void Main() { + const std::string binding_type = absl::GetFlag(FLAGS_binding_type); + if (binding_type == "c99_h") { + std::cout << C99Declarations()->GenerateCode(); + } else if (binding_type == "c99_cc") { + std::cout << C99Definitions()->GenerateCode(); + } else if (binding_type == "python_enums") { + std::cout << PythonEnums()->GenerateCode(); + } else { + LOG(FATAL) << "unknown binding type: '" << binding_type << "'"; + } +} + +} // namespace +} // namespace operations_research::math_opt::codegen + +int main(int argc, char** argv) { + InitGoogle(argv[0], &argc, &argv, /*remove_flags=*/true); + operations_research::math_opt::codegen::Main(); + return 0; +} diff --git a/ortools/math_opt/elemental/codegen/gen.cc b/ortools/math_opt/elemental/codegen/gen.cc new file mode 100644 index 0000000000..1fc55572ad --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen.cc @@ -0,0 +1,158 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/codegen/gen.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/math_opt/elemental/arrays.h" +#include "ortools/math_opt/elemental/attributes.h" +#include "ortools/math_opt/elemental/elements.h" + +namespace operations_research::math_opt::codegen { + +namespace { + +class NamedType : public Type { + public: + explicit NamedType(std::string name) : name_(std::move(name)) {} + + void Print(absl::string_view, std::string* out) const final { + absl::StrAppend(out, name_); + } + + private: + std::string name_; +}; + +class PointerType : public Type { + public: + explicit PointerType(std::shared_ptr pointee) + : pointee_(std::move(pointee)) {} + + void Print(absl::string_view attr_value_type, std::string* out) const final { + pointee_->Print(attr_value_type, out); + absl::StrAppend(out, "*"); + } + + private: + std::shared_ptr pointee_; +}; + +class AttrValueTypeType : public Type { + public: + void Print(absl::string_view attr_value_type, std::string* out) const final { + absl::StrAppend(out, attr_value_type); + } +}; + +} // namespace + +std::shared_ptr Type::Named(std::string name) { + return std::make_shared(std::move(name)); +} + +std::shared_ptr Type::Pointer(std::shared_ptr pointee) { + return std::make_shared(std::move(pointee)); +} + +std::shared_ptr Type::AttrValueType() { + return std::make_shared(); +} + +Type::~Type() = default; + +CodegenAttrTypeDescriptor::ValueType GetValueType(bool) { + return CodegenAttrTypeDescriptor::ValueType::kBool; +} + +CodegenAttrTypeDescriptor::ValueType GetValueType(int64_t) { + return CodegenAttrTypeDescriptor::ValueType::kInt64; +} + +CodegenAttrTypeDescriptor::ValueType GetValueType(double) { + return CodegenAttrTypeDescriptor::ValueType::kDouble; +} + +template +CodegenAttrTypeDescriptor::ValueType GetValueType(ElementId) { + // Element ids are untyped in wrapped APIs. + return CodegenAttrTypeDescriptor::ValueType::kInt64; +} + +template +CodegenAttrTypeDescriptor MakeAttrTypeDescriptor() { + CodegenAttrTypeDescriptor descriptor; + descriptor.value_type = GetValueType(typename Descriptor::ValueType{}); + descriptor.name = Descriptor::kName; + descriptor.num_key_elements = Descriptor::kNumKeyElements; + descriptor.symmetry = Descriptor::Symmetry::GetName(); + + descriptor.attribute_names.reserve(Descriptor::NumAttrs()); + for (const auto& attr_descriptor : Descriptor::kAttrDescriptors) { + descriptor.attribute_names.push_back(attr_descriptor.name); + } + return descriptor; +} + +constexpr absl::string_view kOpNames[static_cast(AttrOp::kNumOps)] = { + "Get", "Set", "IsNonDefault", "NumNonDefaults", "GetNonDefaults"}; + +void CodeGenerator::EmitAttrType(const CodegenAttrTypeDescriptor& descriptor, + std::string* out) const { + StartAttrType(descriptor, out); + for (int op = 0; op < kNumAttrOps; ++op) { + const AttrOpFunctionInfo& op_info = attr_op_function_infos_[op]; + EmitAttrOp(kOpNames[op], descriptor, op_info, out); + } +} + +void CodeGenerator::EmitAttributes( + absl::Span descriptors, + std::string* out) const { + for (const auto& descriptor : descriptors) { + StartAttrType(descriptor, out); + for (int i = 0; i < kNumAttrOps; ++i) { + EmitAttrOp(kOpNames[i], descriptor, attr_op_function_infos_[i], out); + } + } +} + +std::string CodeGenerator::GenerateCode() const { + std::string out; + EmitHeader(&out); + + // Generate elements. + EmitElements(kElementNames, &out); + + // Generate attributes. + std::vector attr_type_descriptors; + ForEach( + [&attr_type_descriptors](auto type_descriptor) { + attr_type_descriptors.push_back( + MakeAttrTypeDescriptor()); + }, + AllAttrTypeDescriptors{}); + EmitAttributes(attr_type_descriptors, &out); + + return out; +} + +} // namespace operations_research::math_opt::codegen diff --git a/ortools/math_opt/elemental/codegen/gen.h b/ortools/math_opt/elemental/codegen/gen.h new file mode 100644 index 0000000000..62faf50a14 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen.h @@ -0,0 +1,143 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Language-agnostic utilities for `Elemental` codegen. +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace operations_research::math_opt::codegen { + +// The list of attribute operations supported by `Elemental`. +enum class AttrOp { + kGet, + kSet, + kIsNonDefault, + kNumNonDefaults, + kGetNonDefaults, + // Do not use. + kNumOps, +}; + +static constexpr int kNumAttrOps = static_cast(AttrOp::kNumOps); + +// A struct to represent an attribute type descriptor during codegen. +struct CodegenAttrTypeDescriptor { + // The attribute type name. + absl::string_view name; + // The value type of the attribute. + enum class ValueType { + kBool, + kInt64, + kDouble, + }; + ValueType value_type; + // The number of key elements. + int num_key_elements; + // The key symmetry. + std::string symmetry; + + // The names of the attributes of this type. + std::vector attribute_names; +}; + +// Representations for types. +class Type { + public: + // A named type, e.g. "double". + static std::shared_ptr Named(std::string name); + // A pointer type. + static std::shared_ptr Pointer(std::shared_ptr pointee); + // A placeholder for the attribute value type, which is yet unknown when types + // are defined. This gets replaced by `attr_value_type` when calling `Print`. + static std::shared_ptr AttrValueType(); + + virtual ~Type(); + + // Prints the type to `out`, replacing `AttrValueType` placeholders with + // `attr_value_type`. + virtual void Print(absl::string_view attr_value_type, + std::string* out) const = 0; +}; + +// Information about how to codegen a given `AttrOp` in a given language. +struct AttrOpFunctionInfo { + // The return type of the function. + std::shared_ptr return_type; + + // If true, the function has an `AttrKey` parameter. + bool has_key_parameter; + + // Extra parameters (e.g. {"double", "value"} for `Set` operations). + struct ExtraParameter { + std::shared_ptr type; + std::string name; + }; + std::vector extra_parameters; +}; + +using AttrOpFunctionInfos = std::array; + +// The code generator interface. +class CodeGenerator { + public: + explicit CodeGenerator(const AttrOpFunctionInfos* attr_op_function_infos) + : attr_op_function_infos_(*attr_op_function_infos) {} + + virtual ~CodeGenerator() = default; + + // Generates code. + std::string GenerateCode() const; + + // Emits the header for the generated code. + virtual void EmitHeader(std::string* out) const {} + + // Emits code for elements. + virtual void EmitElements(absl::Span elements, + std::string* out) const {} + + // Emits code for attributes. By default, this iterates attributes and for + // each attribute: + // - calls `StartAttrType`, and + // - calls `EmitAttrOp` for each operation. + virtual void EmitAttributes( + absl::Span descriptors, + std::string* out) const; + + // Called before generating code for an attribute type. + virtual void StartAttrType(const CodegenAttrTypeDescriptor& descriptor, + std::string* out) const {} + + // Emits code for operation `info` for attribute described by `descriptor`. + virtual void EmitAttrOp(absl::string_view op_name, + const CodegenAttrTypeDescriptor& descriptor, + const AttrOpFunctionInfo& info, + std::string* out) const {} + + private: + void EmitAttrType(const CodegenAttrTypeDescriptor& descriptor, + std::string* out) const; + + const AttrOpFunctionInfos& attr_op_function_infos_; +}; + +} // namespace operations_research::math_opt::codegen + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_ diff --git a/ortools/math_opt/elemental/codegen/gen_c.cc b/ortools/math_opt/elemental/codegen/gen_c.cc new file mode 100644 index 0000000000..3b7b3768ad --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_c.cc @@ -0,0 +1,245 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/codegen/gen_c.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/math_opt/elemental/codegen/gen.h" + +namespace operations_research::math_opt::codegen { + +namespace { + +// A helper to generate parameters to pass `n` key element indices, e.g: +// ", int64_t key_0, int64_t key_1" (parameters) +void AddKeyParams(int n, std::string* out) { + for (int i = 0; i < n; ++i) { + absl::StrAppend(out, ", int64_t key_", i); + } +} + +// A helper to generate an AttrKey argument to pass `n` key element indices, +// e.g: "AttrKey<2, NoSymmetry>(key_0, key_1)". +void AddAttrKeyArg(int n, absl::string_view symmetry, std::string* out) { + absl::StrAppendFormat(out, ", AttrKey<%i, %s>(", n, symmetry); + for (int i = 0; i < n; ++i) { + if (i != 0) { + absl::StrAppend(out, ", "); + } + absl::StrAppend(out, "key_", i); + } + absl::StrAppend(out, ")"); +} + +// Returns the C99 name for the given type. +absl::string_view GetCTypeName( + CodegenAttrTypeDescriptor::ValueType value_type) { + switch (value_type) { + case CodegenAttrTypeDescriptor::ValueType::kBool: + return "_Bool"; + case CodegenAttrTypeDescriptor::ValueType::kInt64: + return "int64_t"; + case CodegenAttrTypeDescriptor::ValueType::kDouble: + return "double"; + } +} + +// Turns an element/attribute name (e.g. "some_name") into a camel case name +// (e.g. "SomeName"). +std::string NameToCamelCase(absl::string_view attr_name) { + std::string result; + result.reserve(attr_name.size()); + CHECK(!attr_name.empty()); + const char first = attr_name[0]; + CHECK(absl::ascii_isalpha(first) && absl::ascii_islower(first)) + << "invalid attr name: " << attr_name; + result.push_back(absl::ascii_toupper(first)); + for (int i = 1; i < attr_name.size(); ++i) { + const char c = attr_name[i]; + if (c == '_') { + ++i; + CHECK(i < attr_name.size()) << "invalid attr name: " << attr_name; + const char next_c = attr_name[i]; + CHECK(absl::ascii_isalnum(next_c)) << "invalid attr name: " << attr_name; + result.push_back(absl::ascii_toupper(next_c)); + } else { + CHECK(absl::ascii_isalnum(c)) << "invalid attr name: " << attr_name; + CHECK(absl::ascii_islower(c)) << "invalid attr name: " << attr_name; + result.push_back(c); + } + } + return result; +} + +// Returns the type of the C status. +std::shared_ptr GetStatusType() { return Type::Named("int"); } + +const AttrOpFunctionInfos* GetC99FunctionInfos() { + static const auto* const kResult = new AttrOpFunctionInfos({ + // Get. + AttrOpFunctionInfo{ + .return_type = GetStatusType(), + .has_key_parameter = true, + .extra_parameters = {{.type = Type::Pointer(Type::AttrValueType()), + .name = "value"}}}, + // Set. + AttrOpFunctionInfo{.return_type = GetStatusType(), + .has_key_parameter = true, + .extra_parameters = {{.type = Type::AttrValueType(), + .name = "value"}}}, + // IsNonDefault. + AttrOpFunctionInfo{ + .return_type = GetStatusType(), + .has_key_parameter = true, + .extra_parameters = {{.type = Type::Pointer(Type::Named("_Bool")), + .name = "out_is_non_default"}}}, + // NumNonDefaults. + AttrOpFunctionInfo{ + .return_type = GetStatusType(), + .has_key_parameter = false, + .extra_parameters = {{.type = Type::Pointer(Type::Named("int64_t")), + .name = "out_num_non_defaults"}}}, + // GetNonDefaults. + AttrOpFunctionInfo{ + .return_type = GetStatusType(), + .has_key_parameter = false, + .extra_parameters = + { + {.type = Type::Pointer(Type::Named("int64_t")), + .name = "out_num_non_defaults"}, + {.type = Type::Pointer(Type::Pointer(Type::Named("int64_t"))), + .name = "out_non_defaults"}, + }}, + }); + return kResult; +} + +class C99CodeGeneratorBase : public CodeGenerator { + public: + using CodeGenerator::CodeGenerator; + + void EmitHeader(std::string* out) const final { + absl::StrAppend(out, R"( +/* DO NOT EDIT: This file is autogenerated. */ +#ifndef MATHOPTH_GENERATED +#error "this file is intended to be included, do not use directly" +#endif +)"); + } +}; + +// Emits the prototype for a function. +void EmitPrototype(absl::string_view op_name, + const CodegenAttrTypeDescriptor& descriptor, + const AttrOpFunctionInfo& info, std::string* out) { + absl::string_view attr_value_type = GetCTypeName(descriptor.value_type); + // Adds the return type, function name and common parameters. + info.return_type->Print(attr_value_type, out); + absl::StrAppendFormat(out, + " MathOpt%s%s(struct " + "MathOptElemental* e, int attr", + descriptor.name, op_name); + // Add the key. + if (info.has_key_parameter) { + AddKeyParams(descriptor.num_key_elements, out); + } + // Add extra parameters. + for (const auto& extra_param : info.extra_parameters) { + absl::StrAppend(out, ", "); + extra_param.type->Print(attr_value_type, out); + absl::StrAppend(out, " ", extra_param.name); + } + // Finish prototype. + absl::StrAppend(out, ")"); +} + +class C99DeclarationsGenerator : public C99CodeGeneratorBase { + public: + C99DeclarationsGenerator() : C99CodeGeneratorBase(GetC99FunctionInfos()) {} + + void EmitElements(absl::Span elements, + std::string* out) const override { + // Generate an enum for the elements. + absl::StrAppend(out, + "// The type of an element in the model.\n" + "enum MathOptElementType {\n"); + for (const auto& element_name : elements) { + absl::StrAppendFormat(out, " kMathOpt%s,\n", + NameToCamelCase(element_name)); + } + absl::StrAppend(out, "};\n\n"); + } + + void EmitAttrOp(absl::string_view op_name, + const CodegenAttrTypeDescriptor& descriptor, + const AttrOpFunctionInfo& info, + std::string* out) const override { + // Just emit a prototype. + EmitPrototype(op_name, descriptor, info, out); + absl::StrAppend(out, ";\n"); + } + + void StartAttrType(const CodegenAttrTypeDescriptor& descriptor, + std::string* out) const override { + // Generate an enum for the attribute type. + absl::StrAppendFormat(out, "typedef enum {\n"); + for (absl::string_view attr_name : descriptor.attribute_names) { + absl::StrAppendFormat(out, " kMathOpt%s%s,\n", descriptor.name, + NameToCamelCase(attr_name)); + } + absl::StrAppendFormat(out, "} MathOpt%s;\n", descriptor.name); + } +}; + +class C99DefinitionsGenerator : public C99CodeGeneratorBase { + public: + C99DefinitionsGenerator() : C99CodeGeneratorBase(GetC99FunctionInfos()) {} + + void EmitAttrOp(absl::string_view op_name, + const CodegenAttrTypeDescriptor& descriptor, + const AttrOpFunctionInfo& info, + std::string* out) const override { + EmitPrototype(op_name, descriptor, info, out); + // Emit a call to the wrapper (e.g. `CAttrOp::Op`). + absl::StrAppendFormat(out, " {\n return CAttrOp<%s>::%s(e, attr", + descriptor.name, op_name); + // Add the key argument. + if (info.has_key_parameter) { + AddAttrKeyArg(descriptor.num_key_elements, descriptor.symmetry, out); + } + // Add extra parameter arguments. + for (const auto& extra_param : info.extra_parameters) { + absl::StrAppend(out, ", ", extra_param.name); + } + absl::StrAppend(out, ");\n}\n"); + } +}; +} // namespace + +std::unique_ptr C99Declarations() { + return std::make_unique(); +} + +std::unique_ptr C99Definitions() { + return std::make_unique(); +} + +} // namespace operations_research::math_opt::codegen diff --git a/ortools/math_opt/elemental/codegen/gen_c.h b/ortools/math_opt/elemental/codegen/gen_c.h new file mode 100644 index 0000000000..8850f5f32d --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_c.h @@ -0,0 +1,32 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The C99 code generator. +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_ + +#include + +#include "ortools/math_opt/elemental/codegen/gen.h" + +namespace operations_research::math_opt::codegen { + +// Returns a generator for C99 declarations. +std::unique_ptr C99Declarations(); + +// Returns a generator for C99 definitions. +std::unique_ptr C99Definitions(); + +} // namespace operations_research::math_opt::codegen + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_ diff --git a/ortools/math_opt/elemental/codegen/gen_c_test.cc b/ortools/math_opt/elemental/codegen/gen_c_test.cc new file mode 100644 index 0000000000..3ed2b89d62 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_c_test.cc @@ -0,0 +1,100 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/codegen/gen_c.h" + +#include + +#include "absl/strings/string_view.h" +#include "gtest/gtest.h" +#include "ortools/math_opt/elemental/codegen/gen.h" +#include "ortools/math_opt/elemental/codegen/testing.h" + +namespace operations_research::math_opt::codegen { +namespace { + +TEST(GenC99DeclarationsTest, EmitElements) { + std::string code; + C99Declarations()->EmitElements({"some_name", "other_name"}, &code); + EXPECT_EQ(code, + R"(// The type of an element in the model. +enum MathOptElementType { + kMathOptSomeName, + kMathOptOtherName, +}; + +)"); +} + +TEST(GenC99DeclarationsTest, StartAttrType) { + std::string code; + C99Declarations()->StartAttrType(GetTestDescriptor(), &code); + EXPECT_EQ(code, + R"(typedef enum { + kMathOptTestAttr2AName, + kMathOptTestAttr2BName, +} MathOptTestAttr2; +)"); +} + +TEST(GenC99DeclarationsTest, WithoutKey) { + std::string code; + C99Declarations()->EmitAttrOp("Op", GetTestDescriptor(), + GetTestFunctionInfo(false), &code); + EXPECT_EQ( + code, + R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, ExtraParam extra_param); +)"); +} + +TEST(GenC99DeclarationsTest, WithKey) { + std::string code; + C99Declarations()->EmitAttrOp("Op", GetTestDescriptor(), + GetTestFunctionInfo(true), &code); + EXPECT_EQ( + code, + R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, int64_t key_0, int64_t key_1, ExtraParam extra_param); +)"); +} + +TEST(GenC99DefinitionsTest, WithoutKey) { + std::string code; + C99Definitions()->EmitAttrOp("Op", GetTestDescriptor(), + GetTestFunctionInfo(false), &code); + EXPECT_EQ( + code, + R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, ExtraParam extra_param) { + return CAttrOp::Op(e, attr, extra_param); +} +)"); +} + +TEST(GenC99DefinitionsTest, WithKey) { + std::string code; + C99Definitions()->EmitAttrOp("Op", GetTestDescriptor(), + GetTestFunctionInfo(true), &code); + EXPECT_EQ( + code, + R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, int64_t key_0, int64_t key_1, ExtraParam extra_param) { + return CAttrOp::Op(e, attr, AttrKey<2, SomeSymmetry>(key_0, key_1), extra_param); +} +)"); +} + +TEST(GenC99DefinitionsTest, StartAttrType) { + std::string code; + C99Definitions()->StartAttrType(GetTestDescriptor(), &code); + EXPECT_EQ(code, ""); +} +} // namespace +} // namespace operations_research::math_opt::codegen diff --git a/ortools/math_opt/elemental/codegen/gen_python.cc b/ortools/math_opt/elemental/codegen/gen_python.cc new file mode 100644 index 0000000000..bd23aae40d --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_python.cc @@ -0,0 +1,161 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/codegen/gen_python.h" + +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/math_opt/elemental/codegen/gen.h" + +namespace operations_research::math_opt::codegen { + +namespace { + +const AttrOpFunctionInfos* GetPythonFunctionInfos() { + // We're not generating functions for python, only enums. + static const auto* const kResult = new AttrOpFunctionInfos(); + return kResult; +} + +// Emits a set of numbered python enumerators for the given range. +void EmitEnumerators(const absl::Span names, + std::string* out) { + for (int i = 0; i < names.size(); ++i) { + absl::StrAppendFormat(out, " %s = %i\n", absl::AsciiStrToUpper(names[i]), + i); + } +} + +// Returns the python type for the given value type. +absl::string_view GetAttrPyValueType( + const CodegenAttrTypeDescriptor::ValueType& value_type) { + switch (value_type) { + case CodegenAttrTypeDescriptor::ValueType::kBool: + return "bool"; + case CodegenAttrTypeDescriptor::ValueType::kInt64: + return "int"; + case CodegenAttrTypeDescriptor::ValueType::kDouble: + return "float"; + } +} + +// Returns the python type for the given value type. +absl::string_view GetAttrNumpyValueType( + const CodegenAttrTypeDescriptor::ValueType& value_type) { + switch (value_type) { + case CodegenAttrTypeDescriptor::ValueType::kBool: + return "np.bool_"; + case CodegenAttrTypeDescriptor::ValueType::kInt64: + return "np.int64"; + case CodegenAttrTypeDescriptor::ValueType::kDouble: + return "np.float64"; + } +} + +class PythonEnumsGenerator : public CodeGenerator { + public: + PythonEnumsGenerator() : CodeGenerator(GetPythonFunctionInfos()) {} + + void EmitHeader(std::string* out) const override { + absl::StrAppend(out, R"( +'''DO NOT EDIT: This file is autogenerated.''' + +import enum +from typing import Generic, TypeVar, Union + +import numpy as np +)"); + } + + void EmitElements(absl::Span elements, + std::string* out) const override { + // Generate an enum for the elements. + absl::StrAppend(out, "class ElementType(enum.Enum):\n"); + EmitEnumerators(elements, out); + absl::StrAppend(out, "\n"); + } + + void EmitAttributes(absl::Span descriptors, + std::string* out) const override { + absl::StrAppend(out, "\n"); + + { + // Collect the list of unique types: + std::set value_types; + for (const auto& descriptor : descriptors) { + value_types.insert(GetAttrNumpyValueType(descriptor.value_type)); + } + + // Emit `AttrValueType`, a type variable for all attribute value types. + absl::StrAppend(out, "AttrValueType = TypeVar('AttrValueType', ", + absl::StrJoin(value_types, ", "), ")\n"); + } + absl::StrAppend(out, "\n"); + { + std::set py_value_types; + for (const auto& descriptor : descriptors) { + py_value_types.insert(GetAttrPyValueType(descriptor.value_type)); + } + absl::StrAppend(out, "AttrPyValueType = TypeVar('AttrPyValueType', ", + absl::StrJoin(py_value_types, ", "), ")\n"); + } + + // `Attr` is an attribute with any value type. + absl::StrAppend(out, R"( +class Attr(Generic[AttrValueType]): + pass +)"); + + // `PyAttr` is an attribute with any value type. + absl::StrAppend(out, R"( +class PyAttr(Generic[AttrPyValueType]): + pass +)"); + + // Generate an enum for the attribute type. + for (const auto& descriptor : descriptors) { + absl::StrAppendFormat( + out, "\nclass %s(Attr[%s], PyAttr[%s], int, enum.Enum):\n", + descriptor.name, GetAttrNumpyValueType(descriptor.value_type), + GetAttrPyValueType(descriptor.value_type)); + EmitEnumerators(descriptor.attribute_names, out); + absl::StrAppend(out, "\n"); + } + + // Add a type alias for the union of all attribute types. + absl::StrAppend( + out, "AnyAttr = Union[", + absl::StrJoin( + descriptors, ", ", + [](std::string* out, const CodegenAttrTypeDescriptor& descriptor) { + absl::StrAppend(out, descriptor.name); + }), + "]\n"); + } +}; + +} // namespace + +std::unique_ptr PythonEnums() { + return std::make_unique(); +} + +} // namespace operations_research::math_opt::codegen diff --git a/ortools/math_opt/elemental/codegen/gen_python.h b/ortools/math_opt/elemental/codegen/gen_python.h new file mode 100644 index 0000000000..309ce17193 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_python.h @@ -0,0 +1,30 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The python code generator. +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_ + +#include + +#include "ortools/math_opt/elemental/codegen/gen.h" + +namespace operations_research::math_opt::codegen { + +// Returns a generator for python enums, independent of the actual +// implementation. These are used by the protocol. +std::unique_ptr PythonEnums(); + +} // namespace operations_research::math_opt::codegen + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_ diff --git a/ortools/math_opt/elemental/codegen/gen_python_test.cc b/ortools/math_opt/elemental/codegen/gen_python_test.cc new file mode 100644 index 0000000000..2c13d64a1d --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_python_test.cc @@ -0,0 +1,60 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/codegen/gen_python.h" + +#include + +#include "gtest/gtest.h" +#include "ortools/math_opt/elemental/codegen/testing.h" + +namespace operations_research::math_opt::codegen { +namespace { + +TEST(GenPythonEnumsTest, EmitElements) { + std::string code; + PythonEnums()->EmitElements({"some_name", "other_name"}, &code); + EXPECT_EQ(code, + R"(class ElementType(enum.Enum): + SOME_NAME = 0 + OTHER_NAME = 1 + +)"); +} + +TEST(GenPythonEnumsTest, EmitAttributes) { + std::string code; + PythonEnums()->EmitAttributes({GetTestDescriptor()}, &code); + EXPECT_EQ(code, + R"( +AttrValueType = TypeVar('AttrValueType', np.float64) + +AttrPyValueType = TypeVar('AttrPyValueType', float) + +class Attr(Generic[AttrValueType]): + pass + +class PyAttr(Generic[AttrPyValueType]): + pass + +class TestAttr2(Attr[np.float64], PyAttr[float], int, enum.Enum): + A_NAME = 0 + B_NAME = 1 + +AnyAttr = Union[TestAttr2] +)"); +} + +} // namespace + +} // namespace operations_research::math_opt::codegen diff --git a/ortools/math_opt/elemental/codegen/gen_test.cc b/ortools/math_opt/elemental/codegen/gen_test.cc new file mode 100644 index 0000000000..8d51750d89 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/gen_test.cc @@ -0,0 +1,96 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/elemental/codegen/gen.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" + +namespace operations_research::math_opt::codegen { +namespace { +using testing::HasSubstr; +using testing::StartsWith; + +const AttrOpFunctionInfos* GetFunctionInfos() { + static const auto* const kResult = new AttrOpFunctionInfos({ + {.return_type = Type::Named("TypeForGet"), + .has_key_parameter = false, + .extra_parameters = {}}, + {.return_type = Type::Pointer(Type::AttrValueType()), + .has_key_parameter = true, + .extra_parameters = {}}, + {.return_type = Type::Named("T"), + .has_key_parameter = false, + .extra_parameters = {}}, + {.return_type = Type::Named("T"), + .has_key_parameter = false, + .extra_parameters = {}}, + {.return_type = Type::Named("T"), + .has_key_parameter = false, + .extra_parameters = {}}, + }); + return kResult; +} + +class TestCodeGenerator : public CodeGenerator { + public: + TestCodeGenerator() : CodeGenerator(GetFunctionInfos()) {} + + void EmitHeader(std::string* out) const override { + absl::StrAppend(out, "# DO NOT EDIT: Test\n"); + } + + void EmitElements(absl::Span elements, + std::string* out) const override { + absl::StrAppend(out, "Elements: ", absl::StrJoin(elements, ", "), "\n"); + } + + void StartAttrType(const CodegenAttrTypeDescriptor&, + std::string* out) const override { + absl::StrAppend(out, "\n"); + } + + void EmitAttrOp(absl::string_view op_name, + const CodegenAttrTypeDescriptor& descriptor, + const AttrOpFunctionInfo& info, + std::string* out) const override { + info.return_type->Print("fake_type", out); + absl::StrAppend(out, " ", descriptor.name, op_name, "\n"); + } +}; + +TEST(GenerateCodeTest, Attrs) { + const std::string code = TestCodeGenerator().GenerateCode(); + EXPECT_THAT(code, StartsWith("# DO NOT EDIT: Test\n")); + EXPECT_THAT(code, HasSubstr("Elements: variable, linear_constraint, ")); + EXPECT_THAT(code, HasSubstr("TypeForGet BoolAttr0Get\n" + "fake_type* BoolAttr0Set\n" + "T BoolAttr0IsNonDefault\n" + "T BoolAttr0NumNonDefaults\n" + "T BoolAttr0GetNonDefaults\n")); + EXPECT_THAT(code, HasSubstr("TypeForGet DoubleAttr1Get\n" + "fake_type* DoubleAttr1Set\n" + "T DoubleAttr1IsNonDefault\n" + "T DoubleAttr1NumNonDefaults\n" + "T DoubleAttr1GetNonDefaults\n")); +} + +} // namespace +} // namespace operations_research::math_opt::codegen diff --git a/ortools/math_opt/elemental/codegen/testing.h b/ortools/math_opt/elemental/codegen/testing.h new file mode 100644 index 0000000000..c9cb147232 --- /dev/null +++ b/ortools/math_opt/elemental/codegen/testing.h @@ -0,0 +1,40 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test descriptors. This avoids depending on attributes from `attributes.h` +// in the tests to decouple the codegen tests from `attributes.h`. +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_ + +#include "ortools/math_opt/elemental/codegen/gen.h" + +namespace operations_research::math_opt::codegen { + +inline CodegenAttrTypeDescriptor GetTestDescriptor() { + return {.name = "TestAttr2", + .value_type = CodegenAttrTypeDescriptor::ValueType::kDouble, + .num_key_elements = 2, + .symmetry = "SomeSymmetry", + .attribute_names = {"a_name", "b_name"}}; +} + +inline AttrOpFunctionInfo GetTestFunctionInfo(bool with_key_parameter) { + return {.return_type = Type::Named("ReturnType"), + .has_key_parameter = with_key_parameter, + .extra_parameters = { + {{.type = Type::Named("ExtraParam"), .name = "extra_param"}}}}; +} + +} // namespace operations_research::math_opt::codegen + +#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_ diff --git a/ortools/math_opt/elemental/derived_data.h b/ortools/math_opt/elemental/derived_data.h new file mode 100644 index 0000000000..7ae7ca0358 --- /dev/null +++ b/ortools/math_opt/elemental/derived_data.h @@ -0,0 +1,218 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_ +#define OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_ + +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "ortools/math_opt/elemental/arrays.h" +#include "ortools/math_opt/elemental/attr_key.h" +#include "ortools/math_opt/elemental/attributes.h" +#include "ortools/math_opt/elemental/elements.h" +#include "ortools/util/fp_roundtrip_conv.h" + +namespace operations_research::math_opt { + +// A helper to manipulate the list of attributes. +struct AllAttrs { + // The number of available attribute types. + static constexpr int kNumAttrTypes = + std::tuple_size_v; + + // Returns the descriptor of the `i-th` attribute type in the list. + template + using TypeDescriptor = std::tuple_element_t; + + // Returns the `i-th` attribute type in the list. + template + using Type = typename TypeDescriptor::AttrType; + + // Returns the index of attribute type `AttrT`. + // Fails to compile if `AttrT` is not an attribute. + template + static constexpr int GetIndex() { + constexpr int index = GetIndexIfAttr(); + // This weird construct is to show `AttrT` explicitly instead of letting + // the user fish it out of the stack trace when the static_assert fails. + static_assert( + std::is_const_v= 0), const AttrT, AttrT>>, + "no such attribute"); + return index; + } + + // Applies `fn` on each value for each attribute type. `fn` must have a + // overload set of `operator(AttrType)` that accepts a `AttrType` for + // each attribute type. + template + static void ForEachAttr(Fn&& fn) { + ForEach( + [&fn](const auto& descriptor) { + for (auto attr : descriptor.Enumerate()) { + fn(attr); + } + }, + AllAttrTypeDescriptors{}); + } +}; + +// Returns the descriptor for attribute `AttrT`. +template +using AttrTypeDescriptorT = + AllAttrs::TypeDescriptor()>; + +// Returns the default value for the attribute type `attr`. +// +// For example GetAttrDefaultValue() returns 0.0. +template +constexpr typename AttrTypeDescriptorT::ValueType +GetAttrDefaultValue() { + return AttrTypeDescriptorT::kAttrDescriptors[static_cast( + attr)] + .default_value; +} + +// Returns the number of elements in a key for the attribute type `AttrType`. +// +// For example `GetAttrKeySize()` returns 2. +template +constexpr int GetAttrKeySize() { + return AttrTypeDescriptorT::kNumKeyElements; +} +template +constexpr int GetAttrKeySize() { + return GetAttrKeySize(); +} + +// The type of the `AttrKey` for attribute type `AttrType`. +template +using AttrKeyFor = AttrKey::kNumKeyElements, + typename AttrTypeDescriptorT::Symmetry>; + +// The value type for attribute type `AttrType`. +template +using ValueTypeFor = typename AttrTypeDescriptorT::ValueType; + +// Returns the array of elements for the key for the attribute type `attr`. +// +// For example, GetElementTypes() returns the array +// {ElementType::kLinearConstraint, ElementType::kVariable}. +template +constexpr std::array()> GetElementTypes( + const AttrType attr) { + return AttrTypeDescriptorT::kAttrDescriptors[static_cast(attr)] + .key_types; +} +template +constexpr std::array()> GetElementTypes() { + return GetElementTypes(attr); +} + +// After C++20, this can be replaced by a lambda. C++17 does not allow lambdas +// in unevaluated contexts. +template