diff --git a/ortools/math_opt/BUILD.bazel b/ortools/math_opt/BUILD.bazel index 1ea2bcfa53..0560e2be2b 100644 --- a/ortools/math_opt/BUILD.bazel +++ b/ortools/math_opt/BUILD.bazel @@ -11,8 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") load("@rules_python//python:proto.bzl", "py_proto_library") package(default_visibility = ["//visibility:public"]) diff --git a/ortools/math_opt/python/BUILD.bazel b/ortools/math_opt/python/BUILD.bazel index 170edd664b..cae132a5fd 100644 --- a/ortools/math_opt/python/BUILD.bazel +++ b/ortools/math_opt/python/BUILD.bazel @@ -28,6 +28,7 @@ py_library( ":errors", ":expressions", ":hash_model_storage", + ":init_arguments", ":message_callback", ":model", ":model_parameters", @@ -160,6 +161,7 @@ py_library( ":callback", ":compute_infeasible_subsystem_result", ":errors", + ":init_arguments", ":message_callback", ":model", ":model_parameters", @@ -210,3 +212,12 @@ py_library( srcs = ["errors.py"], deps = ["//ortools/math_opt:rpc_py_pb2"], ) + +py_library( + name = "init_arguments", + srcs = ["init_arguments.py"], + deps = [ + "//ortools/math_opt:parameters_py_pb2", + "//ortools/math_opt/solvers:gurobi_py_pb2", + ], +) diff --git a/ortools/math_opt/python/init_arguments.py b/ortools/math_opt/python/init_arguments.py new file mode 100644 index 0000000000..cfaab67d82 --- /dev/null +++ b/ortools/math_opt/python/init_arguments.py @@ -0,0 +1,180 @@ +# Copyright 2010-2024 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Configures the instantiation of the underlying solver.""" + +import dataclasses +from typing import Optional + +from ortools.math_opt import parameters_pb2 +from ortools.math_opt.solvers import gurobi_pb2 + + +@dataclasses.dataclass +class StreamableGScipInitArguments: + """Streamable GScip specific parameters for solver instantiation.""" + + +@dataclasses.dataclass(frozen=True) +class GurobiISVKey: + """The Gurobi ISV key, an alternative to license files. + + Contact Gurobi for details. + + Attributes: + name: A string, typically a company/organization. + application_name: A string, typically a project. + expiration: An int, a value of 0 indicates no expiration. + key: A string, the secret. + """ + + name: str = "" + application_name: str = "" + expiration: int = 0 + key: str = "" + + def to_proto(self) -> gurobi_pb2.GurobiInitializerProto.ISVKey: + """Returns a protocol buffer equivalent of this.""" + return gurobi_pb2.GurobiInitializerProto.ISVKey( + name=self.name, + application_name=self.application_name, + expiration=self.expiration, + key=self.key, + ) + + +def gurobi_isv_key_from_proto( + proto: gurobi_pb2.GurobiInitializerProto.ISVKey, +) -> GurobiISVKey: + """Returns an equivalent GurobiISVKey to the input proto.""" + return GurobiISVKey( + name=proto.name, + application_name=proto.application_name, + expiration=proto.expiration, + key=proto.key, + ) + + +@dataclasses.dataclass +class StreamableGurobiInitArguments: + """Streamable Gurobi specific parameters for solver instantiation.""" + + isv_key: Optional[GurobiISVKey] = None + + def to_proto(self) -> gurobi_pb2.GurobiInitializerProto: + """Returns a protocol buffer equivalent of this.""" + return gurobi_pb2.GurobiInitializerProto( + isv_key=self.isv_key.to_proto() if self.isv_key else None + ) + + +def streamable_gurobi_init_arguments_from_proto( + proto: gurobi_pb2.GurobiInitializerProto, +) -> StreamableGurobiInitArguments: + """Returns an equivalent StreamableGurobiInitArguments to the input proto.""" + result = StreamableGurobiInitArguments() + if proto.HasField("isv_key"): + result.isv_key = gurobi_isv_key_from_proto(proto.isv_key) + return result + + +@dataclasses.dataclass +class StreamableGlopInitArguments: + """Streamable Glop specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableCpSatInitArguments: + """Streamable CP-SAT specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamablePdlpInitArguments: + """Streamable Pdlp specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableGlpkInitArguments: + """Streamable GLPK specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableOsqpInitArguments: + """Streamable OSQP specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableEcosInitArguments: + """Streamable Ecos specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableScsInitArguments: + """Streamable Scs specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableHighsInitArguments: + """Streamable Highs specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableSantoriniInitArguments: + """Streamable Santorini specific parameters for solver instantiation.""" + + +@dataclasses.dataclass +class StreamableSolverInitArguments: + """Solver initialization parameters that can be sent to another process. + + Attributes: + gscip: Initialization parameters specific to GScip. + gurobi: Initialization parameters specific to Gurobi. + glop: Initialization parameters specific to GLOP. + cp_sat: Initialization parameters specific to CP-SAT. + pdlp: Initialization parameters specific to PDLP. + glpk: Initialization parameters specific to GLPK. + osqp: Initialization parameters specific to OSQP. + ecos: Initialization parameters specific to ECOS. + scs: Initialization parameters specific to SCS. + highs: Initialization parameters specific to HiGHS. + santorini: Initialization parameters specific to Santorini. + """ + + gscip: Optional[StreamableGScipInitArguments] = None + gurobi: Optional[StreamableGurobiInitArguments] = None + glop: Optional[StreamableGlopInitArguments] = None + cp_sat: Optional[StreamableCpSatInitArguments] = None + pdlp: Optional[StreamablePdlpInitArguments] = None + glpk: Optional[StreamableGlpkInitArguments] = None + osqp: Optional[StreamableOsqpInitArguments] = None + ecos: Optional[StreamableEcosInitArguments] = None + scs: Optional[StreamableScsInitArguments] = None + highs: Optional[StreamableHighsInitArguments] = None + santorini: Optional[StreamableSantoriniInitArguments] = None + + def to_proto(self) -> parameters_pb2.SolverInitializerProto: + """Returns a protocol buffer equivalent of this.""" + return parameters_pb2.SolverInitializerProto( + gurobi=self.gurobi.to_proto() if self.gurobi else None + ) + + +def streamable_solver_init_arguments_from_proto( + proto: parameters_pb2.SolverInitializerProto, +) -> StreamableSolverInitArguments: + """Returns an equivalent StreamableSolverInitArguments to the input proto.""" + result = StreamableSolverInitArguments() + if proto.HasField("gurobi"): + result.gurobi = streamable_gurobi_init_arguments_from_proto(proto.gurobi) + return result diff --git a/ortools/math_opt/python/init_arguments_test.py b/ortools/math_opt/python/init_arguments_test.py new file mode 100644 index 0000000000..eefcbf6f1a --- /dev/null +++ b/ortools/math_opt/python/init_arguments_test.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2010-2024 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from ortools.math_opt import parameters_pb2 +from ortools.math_opt.python import init_arguments +from ortools.math_opt.python.testing import compare_proto +from ortools.math_opt.solvers import gurobi_pb2 + + +class GurobiISVKeyTest(absltest.TestCase, compare_proto.MathOptProtoAssertions): + + def test_proto_conversions(self) -> None: + isv = init_arguments.GurobiISVKey( + name="cat", application_name="hat", expiration=4, key="bat" + ) + proto_isv = gurobi_pb2.GurobiInitializerProto.ISVKey( + name="cat", application_name="hat", expiration=4, key="bat" + ) + self.assert_protos_equiv(isv.to_proto(), proto_isv) + self.assertEqual(init_arguments.gurobi_isv_key_from_proto(proto_isv), isv) + + +class StreamableGurobiInitArgumentsTest( + absltest.TestCase, compare_proto.MathOptProtoAssertions +): + + def test_proto_conversions_isv_key_set(self) -> None: + init = init_arguments.StreamableGurobiInitArguments( + isv_key=init_arguments.GurobiISVKey( + name="cat", application_name="hat", expiration=4, key="bat" + ) + ) + proto_init = gurobi_pb2.GurobiInitializerProto( + isv_key=gurobi_pb2.GurobiInitializerProto.ISVKey( + name="cat", application_name="hat", expiration=4, key="bat" + ) + ) + self.assert_protos_equiv(init.to_proto(), proto_init) + self.assertEqual( + init_arguments.streamable_gurobi_init_arguments_from_proto(proto_init), + init, + ) + + def test_proto_conversions_isv_key_not_set(self) -> None: + init = init_arguments.StreamableGurobiInitArguments() + proto_init = gurobi_pb2.GurobiInitializerProto() + self.assert_protos_equiv(init.to_proto(), proto_init) + self.assertEqual( + init_arguments.streamable_gurobi_init_arguments_from_proto(proto_init), + init, + ) + + +class StreamableSolverInitArgumentsTest( + absltest.TestCase, compare_proto.MathOptProtoAssertions +): + + def test_proto_conversions_gurobi_set(self) -> None: + init = init_arguments.StreamableSolverInitArguments( + gurobi=init_arguments.StreamableGurobiInitArguments( + isv_key=init_arguments.GurobiISVKey( + name="cat", application_name="hat", expiration=4, key="bat" + ) + ) + ) + proto_init = parameters_pb2.SolverInitializerProto( + gurobi=gurobi_pb2.GurobiInitializerProto( + isv_key=gurobi_pb2.GurobiInitializerProto.ISVKey( + name="cat", application_name="hat", expiration=4, key="bat" + ) + ) + ) + self.assert_protos_equiv(init.to_proto(), proto_init) + self.assertEqual( + init_arguments.streamable_solver_init_arguments_from_proto(proto_init), + init, + ) + + def test_proto_conversions_gurobi_not_set(self) -> None: + init = init_arguments.StreamableSolverInitArguments() + proto_init = parameters_pb2.SolverInitializerProto() + self.assert_protos_equiv(init.to_proto(), proto_init) + self.assertEqual( + init_arguments.streamable_solver_init_arguments_from_proto(proto_init), + init, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/ortools/math_opt/python/mathopt.py b/ortools/math_opt/python/mathopt.py index 8191122344..4b48dfe33b 100644 --- a/ortools/math_opt/python/mathopt.py +++ b/ortools/math_opt/python/mathopt.py @@ -65,6 +65,26 @@ from ortools.math_opt.python.errors import status_proto_to_exception from ortools.math_opt.python.expressions import evaluate_expression from ortools.math_opt.python.expressions import fast_sum from ortools.math_opt.python.hash_model_storage import HashModelStorage +from ortools.math_opt.python.init_arguments import gurobi_isv_key_from_proto +from ortools.math_opt.python.init_arguments import GurobiISVKey +from ortools.math_opt.python.init_arguments import ( + streamable_gurobi_init_arguments_from_proto, +) +from ortools.math_opt.python.init_arguments import ( + streamable_solver_init_arguments_from_proto, +) +from ortools.math_opt.python.init_arguments import StreamableCpSatInitArguments +from ortools.math_opt.python.init_arguments import StreamableEcosInitArguments +from ortools.math_opt.python.init_arguments import StreamableGlopInitArguments +from ortools.math_opt.python.init_arguments import StreamableGlpkInitArguments +from ortools.math_opt.python.init_arguments import StreamableGScipInitArguments +from ortools.math_opt.python.init_arguments import StreamableGurobiInitArguments +from ortools.math_opt.python.init_arguments import StreamableHighsInitArguments +from ortools.math_opt.python.init_arguments import StreamableOsqpInitArguments +from ortools.math_opt.python.init_arguments import StreamablePdlpInitArguments +from ortools.math_opt.python.init_arguments import StreamableSantoriniInitArguments +from ortools.math_opt.python.init_arguments import StreamableScsInitArguments +from ortools.math_opt.python.init_arguments import StreamableSolverInitArguments from ortools.math_opt.python.message_callback import list_message_callback from ortools.math_opt.python.message_callback import log_messages from ortools.math_opt.python.message_callback import printer_message_callback diff --git a/ortools/math_opt/python/mathopt_test.py b/ortools/math_opt/python/mathopt_test.py index 259e3e8e1f..08fed7c73d 100644 --- a/ortools/math_opt/python/mathopt_test.py +++ b/ortools/math_opt/python/mathopt_test.py @@ -22,6 +22,7 @@ from absl.testing import absltest from ortools.math_opt.python import callback from ortools.math_opt.python import expressions from ortools.math_opt.python import hash_model_storage +from ortools.math_opt.python import init_arguments from ortools.math_opt.python import mathopt from ortools.math_opt.python import message_callback from ortools.math_opt.python import model @@ -48,6 +49,7 @@ _MODULES_TO_CHECK: List[types.ModuleType] = [ callback, expressions, hash_model_storage, + init_arguments, message_callback, model, model_parameters, diff --git a/ortools/math_opt/python/result.py b/ortools/math_opt/python/result.py index 27988da834..f8d804c469 100644 --- a/ortools/math_opt/python/result.py +++ b/ortools/math_opt/python/result.py @@ -325,6 +325,18 @@ class Termination: problem_status: ProblemStatus = ProblemStatus() objective_bounds: ObjectiveBounds = ObjectiveBounds() + def to_proto(self) -> result_pb2.TerminationProto: + """Returns an equivalent protocol buffer to this Termination.""" + return result_pb2.TerminationProto( + reason=self.reason.value, + limit=( + result_pb2.LIMIT_UNSPECIFIED if self.limit is None else self.limit.value + ), + detail=self.detail, + problem_status=self.problem_status.to_proto(), + objective_bounds=self.objective_bounds.to_proto(), + ) + def parse_termination( termination_proto: result_pb2.TerminationProto, @@ -930,6 +942,39 @@ class SolveResult: f"variable_status: {type(variables).__name__!r}" ) + def to_proto(self) -> result_pb2.SolveResultProto: + """Returns an equivalent protocol buffer for a SolveResult.""" + proto = result_pb2.SolveResultProto( + termination=self.termination.to_proto(), + solutions=[s.to_proto() for s in self.solutions], + primal_rays=[r.to_proto() for r in self.primal_rays], + dual_rays=[r.to_proto() for r in self.dual_rays], + solve_stats=self.solve_stats.to_proto(), + ) + + # Ensure that at most solver has solver specific output. + existing_solver_specific_output = None + + def has_solver_specific_output(solver_name: str) -> None: + nonlocal existing_solver_specific_output + if existing_solver_specific_output is not None: + raise ValueError( + "found solver specific output for both" + f" {existing_solver_specific_output} and {solver_name}" + ) + existing_solver_specific_output = solver_name + + if self.gscip_specific_output is not None: + has_solver_specific_output("gscip") + proto.gscip_output.CopyFrom(self.gscip_specific_output) + if self.osqp_specific_output is not None: + has_solver_specific_output("osqp") + proto.osqp_output.CopyFrom(self.osqp_specific_output) + if self.pdlp_specific_output is not None: + has_solver_specific_output("pdlp") + proto.pdlp_output.CopyFrom(self.pdlp_specific_output) + return proto + def _get_problem_status( result_proto: result_pb2.SolveResultProto, diff --git a/ortools/math_opt/python/result_test.py b/ortools/math_opt/python/result_test.py index 6c4e7beeb1..c2b0830a62 100644 --- a/ortools/math_opt/python/result_test.py +++ b/ortools/math_opt/python/result_test.py @@ -16,6 +16,8 @@ import datetime import math from absl.testing import absltest +from ortools.pdlp import solve_log_pb2 +from ortools.gscip import gscip_pb2 from ortools.math_opt import result_pb2 from ortools.math_opt import solution_pb2 from ortools.math_opt import sparse_containers_pb2 @@ -23,9 +25,10 @@ from ortools.math_opt.python import model from ortools.math_opt.python import result from ortools.math_opt.python import solution from ortools.math_opt.python.testing import compare_proto +from ortools.math_opt.solvers import osqp_pb2 -class ParseTerminationReason(compare_proto.MathOptProtoAssertions, absltest.TestCase): +class TerminationTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): def test_termination_unspecified(self) -> None: termination_proto = result_pb2.TerminationProto( @@ -54,7 +57,19 @@ class ParseTerminationReason(compare_proto.MathOptProtoAssertions, absltest.Test ): result.parse_termination(termination_proto) - def test_termination_ok(self) -> None: + def test_termination_ok_proto_round_trip(self) -> None: + termination = result.Termination( + reason=result.TerminationReason.NO_SOLUTION_FOUND, + limit=result.Limit.OTHER, + detail="detail", + problem_status=result.ProblemStatus( + primal_status=result.FeasibilityStatus.FEASIBLE, + dual_status=result.FeasibilityStatus.INFEASIBLE, + primal_or_dual_infeasible=False, + ), + objective_bounds=result.ObjectiveBounds(primal_bound=10, dual_bound=20), + ) + termination_proto = result_pb2.TerminationProto( reason=result_pb2.TERMINATION_REASON_NO_SOLUTION_FOUND, limit=result_pb2.LIMIT_OTHER, @@ -68,22 +83,12 @@ class ParseTerminationReason(compare_proto.MathOptProtoAssertions, absltest.Test primal_bound=10, dual_bound=20 ), ) - termination = result.parse_termination(termination_proto) - self.assertEqual(termination.reason, result.TerminationReason.NO_SOLUTION_FOUND) - self.assertEqual(termination.limit, result.Limit.OTHER) - self.assertEqual(termination.detail, "detail") - self.assertEqual( - termination.problem_status, - result.ProblemStatus( - primal_status=result.FeasibilityStatus.FEASIBLE, - dual_status=result.FeasibilityStatus.INFEASIBLE, - primal_or_dual_infeasible=False, - ), - ) - self.assertEqual( - termination.objective_bounds, - result.ObjectiveBounds(primal_bound=10, dual_bound=20), - ) + + # Test proto-> Termination + self.assertEqual(result.parse_termination(termination_proto), termination) + + # Test Termination -> proto + self.assert_protos_equiv(termination.to_proto(), termination_proto) class ParseProblemStatus(compare_proto.MathOptProtoAssertions, absltest.TestCase): @@ -617,83 +622,105 @@ def _make_undetermined_result_proto() -> result_pb2.SolveResultProto: primal_bound=math.inf, dual_bound=-math.inf, ), - ), - solutions=[ - solution_pb2.SolutionProto( - primal_solution=solution_pb2.PrimalSolutionProto( - objective_value=2.0, - variable_values=sparse_containers_pb2.SparseDoubleVectorProto( - ids=[0], values=[1.0] - ), - feasibility_status=solution_pb2.SOLUTION_STATUS_UNDETERMINED, - ) - ) - ], + ) ) - proto.solve_stats.problem_status.primal_status = ( - result_pb2.FEASIBILITY_STATUS_UNDETERMINED - ) - proto.solve_stats.problem_status.dual_status = ( - result_pb2.FEASIBILITY_STATUS_UNDETERMINED - ) - proto.solve_stats.problem_status.primal_or_dual_infeasible = False - proto.solve_stats.best_primal_bound = math.inf - proto.solve_stats.best_dual_bound = -math.inf + proto.solve_stats.solve_time.FromTimedelta(datetime.timedelta(minutes=2)) return proto +def _make_undetermined_solve_result() -> result.SolveResult: + return result.SolveResult( + termination=result.Termination( + reason=result.TerminationReason.NO_SOLUTION_FOUND, + limit=result.Limit.TIME, + problem_status=result.ProblemStatus( + primal_status=result.FeasibilityStatus.UNDETERMINED, + dual_status=result.FeasibilityStatus.UNDETERMINED, + ), + objective_bounds=result.ObjectiveBounds( + primal_bound=math.inf, dual_bound=-math.inf + ), + ), + solve_stats=result.SolveStats(solve_time=datetime.timedelta(minutes=2)), + ) + + class SolveResultTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): def test_solve_result_gscip_output(self) -> None: mod = model.Model(name="test_model") - mod.add_binary_variable() + res = _make_undetermined_solve_result() + res.gscip_specific_output = gscip_pb2.GScipOutput(status_detail="gscip_detail") + proto = _make_undetermined_result_proto() proto.gscip_output.status_detail = "gscip_detail" - res = result.parse_solve_result(proto, mod) - assert res.gscip_specific_output is not None - self.assertEqual("gscip_detail", res.gscip_specific_output.status_detail) - def test_solve_result_no_gscip_output(self) -> None: - mod = model.Model(name="test_model") - mod.add_binary_variable() - proto = _make_undetermined_result_proto() - res = result.parse_solve_result(proto, mod) - self.assertIsNone(res.gscip_specific_output) + # proto -> result + actual_res = result.parse_solve_result(proto, mod) + self.assertIsNotNone(actual_res.gscip_specific_output) + assert actual_res.gscip_specific_output is not None + self.assertEqual("gscip_detail", actual_res.gscip_specific_output.status_detail) + self.assertIsNone(actual_res.pdlp_specific_output) + self.assertIsNone(actual_res.osqp_specific_output) + + # result -> proto + self.assert_protos_equiv(res.to_proto(), proto) def test_solve_result_osqp_output(self) -> None: mod = model.Model(name="test_model") - mod.add_binary_variable() - proto = _make_undetermined_result_proto() - proto.osqp_output.initialized_underlying_solver = False - res = result.parse_solve_result(proto, mod) - assert res.osqp_specific_output is not None - self.assertFalse(res.osqp_specific_output.initialized_underlying_solver) + res = _make_undetermined_solve_result() + res.osqp_specific_output = osqp_pb2.OsqpOutput( + initialized_underlying_solver=True + ) - def test_solve_result_no_osqp_output(self) -> None: - mod = model.Model(name="test_model") - mod.add_binary_variable() proto = _make_undetermined_result_proto() - res = result.parse_solve_result(proto, mod) - self.assertIsNone(res.osqp_specific_output) + proto.osqp_output.initialized_underlying_solver = True + + # proto -> result + actual_res = result.parse_solve_result(proto, mod) + self.assertIsNotNone(actual_res.osqp_specific_output) + assert actual_res.osqp_specific_output is not None + self.assertTrue(actual_res.osqp_specific_output.initialized_underlying_solver) + self.assertIsNone(actual_res.pdlp_specific_output) + self.assertIsNone(actual_res.gscip_specific_output) + + # result -> proto + self.assert_protos_equiv(res.to_proto(), proto) def test_solve_result_pdlp_output(self) -> None: mod = model.Model(name="test_model") - mod.add_binary_variable() - proto = _make_undetermined_result_proto() - proto.pdlp_output.convergence_information.corrected_dual_objective = 2.0 - res = result.parse_solve_result(proto, mod) - assert res.pdlp_specific_output is not None - self.assertEqual( - res.pdlp_specific_output.convergence_information.corrected_dual_objective, - 2.0, + res = _make_undetermined_solve_result() + res.pdlp_specific_output = result_pb2.SolveResultProto.PdlpOutput( + convergence_information=solve_log_pb2.ConvergenceInformation( + primal_objective=1.0 + ) ) - def test_solve_result_no_pdlp_output(self) -> None: - mod = model.Model(name="test_model") - mod.add_binary_variable() proto = _make_undetermined_result_proto() - res = result.parse_solve_result(proto, mod) - self.assertIsNone(res.pdlp_specific_output) + proto.pdlp_output.convergence_information.primal_objective = 1.0 + + # proto -> result + actual_res = result.parse_solve_result(proto, mod) + self.assertIsNotNone(actual_res.pdlp_specific_output) + assert actual_res.pdlp_specific_output is not None + self.assertEqual( + actual_res.pdlp_specific_output.convergence_information.primal_objective, + 1.0, + ) + self.assertIsNone(actual_res.osqp_specific_output) + self.assertIsNone(actual_res.gscip_specific_output) + + # result -> proto + self.assert_protos_equiv(res.to_proto(), proto) + + def test_multiple_solver_specific_outputs_error(self) -> None: + res = _make_undetermined_solve_result() + res.gscip_specific_output = gscip_pb2.GScipOutput(status_detail="gscip_detail") + res.osqp_specific_output = osqp_pb2.OsqpOutput( + initialized_underlying_solver=False + ) + with self.assertRaisesRegex(ValueError, "solver specific output"): + res.to_proto() def test_solve_result_from_proto_missing_bounds_in_termination( self, @@ -1048,6 +1075,78 @@ class SolveResultTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): self.assertEqual(20, res.termination.objective_bounds.dual_bound) self.assertIsNone(res.gscip_specific_output) + def test_to_proto_round_trip(self) -> None: + mod = model.Model(name="test_model") + x = mod.add_binary_variable(name="x") + c = mod.add_linear_constraint(lb=0.0, ub=1.0, name="c") + + s = solution.Solution( + primal_solution=solution.PrimalSolution( + variable_values={x: 1.0}, + objective_value=2.0, + feasibility_status=solution.SolutionStatus.FEASIBLE, + ) + ) + r = result.SolveResult( + termination=result.Termination( + reason=result.TerminationReason.FEASIBLE, + limit=result.Limit.TIME, + problem_status=result.ProblemStatus( + primal_status=result.FeasibilityStatus.FEASIBLE, + dual_status=result.FeasibilityStatus.UNDETERMINED, + ), + ), + solve_stats=result.SolveStats( + node_count=3, solve_time=datetime.timedelta(seconds=4) + ), + solutions=[s], + primal_rays=[solution.PrimalRay(variable_values={x: 4.0})], + dual_rays=[solution.DualRay(reduced_costs={x: 5.0}, dual_values={c: 6.0})], + ) + + s_proto = solution_pb2.SolutionProto( + primal_solution=solution_pb2.PrimalSolutionProto( + objective_value=2.0, + feasibility_status=solution_pb2.SOLUTION_STATUS_FEASIBLE, + variable_values=sparse_containers_pb2.SparseDoubleVectorProto( + ids=[0], values=[1.0] + ), + ) + ) + r_proto = result_pb2.SolveResultProto( + termination=result_pb2.TerminationProto( + reason=result_pb2.TERMINATION_REASON_FEASIBLE, + limit=result_pb2.LIMIT_TIME, + problem_status=result_pb2.ProblemStatusProto( + primal_status=result_pb2.FEASIBILITY_STATUS_FEASIBLE, + dual_status=result_pb2.FEASIBILITY_STATUS_UNDETERMINED, + ), + ), + solve_stats=result_pb2.SolveStatsProto(node_count=3), + solutions=[s_proto], + primal_rays=[ + solution_pb2.PrimalRayProto( + variable_values=sparse_containers_pb2.SparseDoubleVectorProto( + ids=[0], values=[4.0] + ) + ) + ], + dual_rays=[ + solution_pb2.DualRayProto( + reduced_costs=sparse_containers_pb2.SparseDoubleVectorProto( + ids=[0], values=[5.0] + ), + dual_values=sparse_containers_pb2.SparseDoubleVectorProto( + ids=[0], values=[6.0] + ), + ) + ], + ) + r_proto.solve_stats.solve_time.FromTimedelta(datetime.timedelta(seconds=4)) + + self.assert_protos_equiv(r.to_proto(), r_proto) + self.assertEqual(result.parse_solve_result(r_proto, mod), r) + if __name__ == "__main__": absltest.main() diff --git a/ortools/math_opt/python/solution.py b/ortools/math_opt/python/solution.py index f36596534f..bc2f21ba4b 100644 --- a/ortools/math_opt/python/solution.py +++ b/ortools/math_opt/python/solution.py @@ -164,6 +164,14 @@ class PrimalRay: default_factory=dict ) + def to_proto(self) -> solution_pb2.PrimalRayProto: + """Returns an equivalent proto to this PrimalRay.""" + return solution_pb2.PrimalRayProto( + variable_values=sparse_containers.to_sparse_double_vector_proto( + self.variable_values + ) + ) + def parse_primal_ray(proto: solution_pb2.PrimalRayProto, mod: model.Model) -> PrimalRay: """Returns an equivalent PrimalRay from the input proto.""" @@ -278,6 +286,17 @@ class DualRay: ) reduced_costs: Dict[model.Variable, float] = dataclasses.field(default_factory=dict) + def to_proto(self) -> solution_pb2.DualRayProto: + """Returns an equivalent proto to this PrimalRay.""" + return solution_pb2.DualRayProto( + dual_values=sparse_containers.to_sparse_double_vector_proto( + self.dual_values + ), + reduced_costs=sparse_containers.to_sparse_double_vector_proto( + self.reduced_costs + ), + ) + def parse_dual_ray(proto: solution_pb2.DualRayProto, mod: model.Model) -> DualRay: """Returns an equivalent DualRay from the input proto.""" diff --git a/ortools/math_opt/python/solution_test.py b/ortools/math_opt/python/solution_test.py index 49d8758b3b..a2203b89f7 100644 --- a/ortools/math_opt/python/solution_test.py +++ b/ortools/math_opt/python/solution_test.py @@ -76,17 +76,24 @@ class ParsePrimalSolutionTest(compare_proto.MathOptProtoAssertions, absltest.Tes solution.parse_primal_solution(proto, mod) -class ParsePrimalRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): +class PrimalRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): - def test_parse(self) -> None: + def test_proto_round_trip(self) -> None: mod = model.Model(name="test_model") x = mod.add_binary_variable(name="x") y = mod.add_binary_variable(name="y") - proto = solution_pb2.PrimalRayProto() - proto.variable_values.ids[:] = [0, 1] - proto.variable_values.values[:] = [1.0, 1.0] - actual = solution.parse_primal_ray(proto, mod) - self.assertDictEqual({x: 1.0, y: 1.0}, actual.variable_values) + ray = solution.PrimalRay(variable_values={x: 1.0, y: 1.0}) + ray_proto = solution_pb2.PrimalRayProto() + ray_proto.variable_values.ids[:] = [0, 1] + ray_proto.variable_values.values[:] = [1.0, 1.0] + + # Test proto -> model + parsed_ray = solution.parse_primal_ray(ray_proto, mod) + self.assertDictEqual({x: 1.0, y: 1.0}, parsed_ray.variable_values) + + # Test model -> proto + exported_ray = ray.to_proto() + self.assert_protos_equiv(exported_ray, ray_proto) class ParseDualSolutionTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): @@ -151,22 +158,33 @@ class ParseDualSolutionTest(compare_proto.MathOptProtoAssertions, absltest.TestC solution.parse_dual_solution(proto, mod) -class ParseDualRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): +class DualRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): - def test_parse(self) -> None: + def test_proto_round_trip(self) -> None: mod = model.Model(name="test_model") x = mod.add_binary_variable(name="x") y = mod.add_binary_variable(name="y") c = mod.add_linear_constraint(lb=0.0, ub=1.0, name="c") d = mod.add_linear_constraint(lb=0.0, ub=1.0, name="d") - proto = solution_pb2.DualRayProto() - proto.dual_values.ids[:] = [0, 1] - proto.dual_values.values[:] = [0.0, 1.0] - proto.reduced_costs.ids[:] = [0, 1] - proto.reduced_costs.values[:] = [10.0, 0.0] - actual = solution.parse_dual_ray(proto, mod) - self.assertDictEqual({x: 10.0, y: 0.0}, actual.reduced_costs) - self.assertDictEqual({c: 0.0, d: 1.0}, actual.dual_values) + + dual_ray = solution.DualRay( + dual_values={c: 0.0, d: 1.0}, reduced_costs={x: 10.0, y: 0.0} + ) + + dual_ray_proto = solution_pb2.DualRayProto() + dual_ray_proto.dual_values.ids[:] = [0, 1] + dual_ray_proto.dual_values.values[:] = [0.0, 1.0] + dual_ray_proto.reduced_costs.ids[:] = [0, 1] + dual_ray_proto.reduced_costs.values[:] = [10.0, 0.0] + + # Test proto -> dual ray + parsed_ray = solution.parse_dual_ray(dual_ray_proto, mod) + self.assertDictEqual(dual_ray.reduced_costs, parsed_ray.reduced_costs) + self.assertDictEqual(dual_ray.dual_values, parsed_ray.dual_values) + + # Test dual ray -> proto + exported_proto = dual_ray.to_proto() + self.assert_protos_equiv(exported_proto, dual_ray_proto) class BasisTest(compare_proto.MathOptProtoAssertions, absltest.TestCase): diff --git a/ortools/math_opt/python/solve.py b/ortools/math_opt/python/solve.py index 500453fe12..d078b1df85 100644 --- a/ortools/math_opt/python/solve.py +++ b/ortools/math_opt/python/solve.py @@ -21,6 +21,7 @@ from ortools.math_opt.core.python import solver from ortools.math_opt.python import callback from ortools.math_opt.python import compute_infeasible_subsystem_result from ortools.math_opt.python import errors +from ortools.math_opt.python import init_arguments from ortools.math_opt.python import message_callback from ortools.math_opt.python import model from ortools.math_opt.python import model_parameters @@ -40,6 +41,7 @@ def solve( msg_cb: Optional[message_callback.SolveMessageCallback] = None, callback_reg: Optional[callback.CallbackRegistration] = None, cb: Optional[SolveCallback] = None, + streamable_init_args: Optional[init_arguments.StreamableSolverInitArguments] = None, ) -> result.SolveResult: """Solves an optimization model. @@ -56,6 +58,7 @@ def solve( callback_reg: Configures when the callback will be invoked (if provided) and what data will be collected to access in the callback. cb: A callback that will be called periodically as the solver runs. + streamable_init_args: Configuration for initializing the underlying solver. Returns: A SolveResult containing the termination reason, solution(s) and stats. @@ -68,6 +71,9 @@ def solve( params = params or parameters.SolveParameters() model_params = model_params or model_parameters.ModelSolveParameters() callback_reg = callback_reg or callback.CallbackRegistration() + streamable_init_args = ( + streamable_init_args or init_arguments.StreamableSolverInitArguments() + ) model_proto = opt_model.export_model() proto_cb = None if cb is not None: @@ -79,7 +85,7 @@ def solve( proto_result = solver.solve( model_proto, solver_type.value, - parameters_pb2.SolverInitializerProto(), + streamable_init_args.to_proto(), params.to_proto(), model_params.to_proto(), msg_cb, @@ -98,6 +104,7 @@ def compute_infeasible_subsystem( *, params: Optional[parameters.SolveParameters] = None, msg_cb: Optional[message_callback.SolveMessageCallback] = None, + streamable_init_args: Optional[init_arguments.StreamableSolverInitArguments] = None, ) -> compute_infeasible_subsystem_result.ComputeInfeasibleSubsystemResult: """Computes an infeasible subsystem of the input model. @@ -107,6 +114,7 @@ def compute_infeasible_subsystem( August 2023, the only supported solver is Gurobi. params: Configuration of the underlying solver. msg_cb: A callback that gives back the underlying solver's logs by the line. + streamable_init_args: Configuration for initializing the underlying solver. Returns: An `ComputeInfeasibleSubsystemResult` where `feasibility` indicates if the @@ -116,13 +124,16 @@ def compute_infeasible_subsystem( RuntimeError: on invalid inputs or an internal solver error. """ params = params or parameters.SolveParameters() + streamable_init_args = ( + streamable_init_args or init_arguments.StreamableSolverInitArguments() + ) model_proto = opt_model.export_model() # Solve try: proto_result = solver.compute_infeasible_subsystem( model_proto, solver_type.value, - parameters_pb2.SolverInitializerProto(), + streamable_init_args.to_proto(), params.to_proto(), msg_cb, None, @@ -163,7 +174,18 @@ class IncrementalSolver: When it is not possible to use `with`, the close() method can be called. """ - def __init__(self, opt_model: model.Model, solver_type: parameters.SolverType): + def __init__( + self, + opt_model: model.Model, + solver_type: parameters.SolverType, + *, + streamable_init_args: Optional[ + init_arguments.StreamableSolverInitArguments + ] = None, + ): + streamable_init_args = ( + streamable_init_args or init_arguments.StreamableSolverInitArguments() + ) self._model = opt_model self._solver_type = solver_type self._update_tracker = self._model.add_update_tracker() @@ -171,7 +193,7 @@ class IncrementalSolver: self._proto_solver = solver.new( solver_type.value, self._model.export_model(), - parameters_pb2.SolverInitializerProto(), + streamable_init_args.to_proto(), ) except StatusNotOk as e: raise _status_not_ok_to_exception(e) from None diff --git a/ortools/math_opt/python/solve_gurobi_test.py b/ortools/math_opt/python/solve_gurobi_test.py index 19b201e008..364cd209c4 100644 --- a/ortools/math_opt/python/solve_gurobi_test.py +++ b/ortools/math_opt/python/solve_gurobi_test.py @@ -19,8 +19,10 @@ machine. """ from absl.testing import absltest +from ortools.gurobi.isv.secret import gurobi_test_isv_key from ortools.math_opt.python import callback from ortools.math_opt.python import compute_infeasible_subsystem_result +from ortools.math_opt.python import init_arguments from ortools.math_opt.python import model from ortools.math_opt.python import parameters from ortools.math_opt.python import result @@ -29,6 +31,18 @@ from ortools.math_opt.python import solve _Bounds = compute_infeasible_subsystem_result.ModelSubsetBounds +_bad_isv_key = init_arguments.GurobiISVKey( + name="cat", application_name="hat", expiration=10, key="bat" +) + + +def _init_args( + gurobi_key: init_arguments.GurobiISVKey, +) -> init_arguments.StreamableSolverInitArguments: + return init_arguments.StreamableSolverInitArguments( + gurobi=init_arguments.StreamableGurobiInitArguments(isv_key=gurobi_key) + ) + class SolveTest(absltest.TestCase): @@ -89,6 +103,95 @@ class SolveTest(absltest.TestCase): ) self.assertEmpty(iis.infeasible_subsystem.variable_integrality) + def test_solve_valid_isv_success(self): + mod = model.Model() + x = mod.add_binary_variable() + mod.maximize(x) + res = solve.solve( + mod, + parameters.SolverType.GUROBI, + streamable_init_args=_init_args( + gurobi_test_isv_key.google_test_isv_key_placeholder() + ), + ) + self.assertEqual( + res.termination.reason, + result.TerminationReason.OPTIMAL, + msg=res.termination, + ) + self.assertAlmostEqual(1.0, res.termination.objective_bounds.primal_bound) + + def test_solve_wrong_isv_error(self): + mod = model.Model() + x = mod.add_binary_variable() + mod.maximize(x) + with self.assertRaisesRegex( + ValueError, "failed to create Gurobi primary environment with ISV key" + ): + solve.solve( + mod, + parameters.SolverType.GUROBI, + streamable_init_args=_init_args(_bad_isv_key), + ) + + def test_incremental_solver_valid_isv_success(self): + mod = model.Model() + x = mod.add_binary_variable() + mod.maximize(x) + s = solve.IncrementalSolver( + mod, + parameters.SolverType.GUROBI, + streamable_init_args=_init_args( + gurobi_test_isv_key.google_test_isv_key_placeholder() + ), + ) + res = s.solve() + self.assertEqual( + res.termination.reason, + result.TerminationReason.OPTIMAL, + msg=res.termination, + ) + self.assertAlmostEqual(1.0, res.termination.objective_bounds.primal_bound) + + def test_incremental_solver_wrong_isv_error(self): + mod = model.Model() + x = mod.add_binary_variable() + mod.maximize(x) + with self.assertRaisesRegex( + ValueError, "failed to create Gurobi primary environment with ISV key" + ): + solve.IncrementalSolver( + mod, + parameters.SolverType.GUROBI, + streamable_init_args=_init_args(_bad_isv_key), + ) + + def test_compute_infeasible_subsystem_valid_isv_success(self): + mod = model.Model() + x = mod.add_binary_variable() + mod.add_linear_constraint(x >= 3.0) + res = solve.compute_infeasible_subsystem( + mod, + parameters.SolverType.GUROBI, + streamable_init_args=_init_args( + gurobi_test_isv_key.google_test_isv_key_placeholder() + ), + ) + self.assertEqual(res.feasibility, result.FeasibilityStatus.INFEASIBLE) + + def test_compute_infeasible_subsystem_wrong_isv_error(self): + mod = model.Model() + x = mod.add_binary_variable() + mod.add_linear_constraint(x >= 3.0) + with self.assertRaisesRegex( + ValueError, "failed to create Gurobi primary environment with ISV key" + ): + solve.compute_infeasible_subsystem( + mod, + parameters.SolverType.GUROBI, + streamable_init_args=_init_args(_bad_isv_key), + ) + if __name__ == "__main__": absltest.main() diff --git a/ortools/math_opt/solvers/BUILD.bazel b/ortools/math_opt/solvers/BUILD.bazel index 5f5f326fab..bd121a46ff 100644 --- a/ortools/math_opt/solvers/BUILD.bazel +++ b/ortools/math_opt/solvers/BUILD.bazel @@ -11,8 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") load("@rules_python//python:proto.bzl", "py_proto_library") package(default_visibility = ["//ortools/math_opt:__subpackages__"]) @@ -312,9 +312,9 @@ cc_test( cc_test( name = "cp_sat_solver_test", + timeout = "eternal", srcs = ["cp_sat_solver_test.cc"], shard_count = 10, - timeout = "eternal", deps = [ ":cp_sat_solver", "//ortools/base:gmock_main", diff --git a/ortools/math_opt/solvers/highs_solver.cc b/ortools/math_opt/solvers/highs_solver.cc index 0aed905cb3..c314f85b3e 100644 --- a/ortools/math_opt/solvers/highs_solver.cc +++ b/ortools/math_opt/solvers/highs_solver.cc @@ -66,6 +66,7 @@ #include "ortools/math_opt/solvers/message_callback_data.h" #include "ortools/util/solve_interrupter.h" #include "ortools/util/status_macros.h" +#include "simplex/SimplexConst.h" #include "util/HighsInt.h" namespace operations_research::math_opt { diff --git a/ortools/math_opt/storage/BUILD.bazel b/ortools/math_opt/storage/BUILD.bazel index 9119a4137d..c8c950cfdb 100644 --- a/ortools/math_opt/storage/BUILD.bazel +++ b/ortools/math_opt/storage/BUILD.bazel @@ -182,11 +182,9 @@ cc_library( ":linear_constraint_storage", ":model_storage_types", ":objective_storage", - ":sparse_matrix", ":update_trackers", ":variable_storage", "//ortools/base:intops", - "//ortools/base:map_util", "//ortools/base:status_macros", "//ortools/math_opt:model_cc_proto", "//ortools/math_opt:model_update_cc_proto", diff --git a/ortools/math_opt/storage/model_storage.cc b/ortools/math_opt/storage/model_storage.cc index 8799d28054..bc23f36b07 100644 --- a/ortools/math_opt/storage/model_storage.cc +++ b/ortools/math_opt/storage/model_storage.cc @@ -21,15 +21,12 @@ #include #include -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "ortools/base/map_util.h" #include "ortools/base/status_macros.h" #include "ortools/base/strong_int.h" #include "ortools/math_opt/core/model_summary.h" @@ -41,7 +38,6 @@ #include "ortools/math_opt/sparse_containers.pb.h" #include "ortools/math_opt/storage/iterators.h" #include "ortools/math_opt/storage/linear_constraint_storage.h" -#include "ortools/math_opt/storage/sparse_matrix.h" #include "ortools/math_opt/storage/update_trackers.h" #include "ortools/math_opt/storage/variable_storage.h" #include "ortools/math_opt/validators/model_validator.h" @@ -84,19 +80,21 @@ absl::StatusOr> ModelStorage::FromModelProto( model_proto.linear_constraint_matrix()); // Add quadratic constraints. - storage->quadratic_constraints_.AddConstraints( + storage->copyable_data_.quadratic_constraints.AddConstraints( model_proto.quadratic_constraints()); // Add SOC constraints. - storage->soc_constraints_.AddConstraints( + storage->copyable_data_.soc_constraints.AddConstraints( model_proto.second_order_cone_constraints()); // Add SOS constraints. - storage->sos1_constraints_.AddConstraints(model_proto.sos1_constraints()); - storage->sos2_constraints_.AddConstraints(model_proto.sos2_constraints()); + storage->copyable_data_.sos1_constraints.AddConstraints( + model_proto.sos1_constraints()); + storage->copyable_data_.sos2_constraints.AddConstraints( + model_proto.sos2_constraints()); // Add indicator constraints. - storage->indicator_constraints_.AddConstraints( + storage->copyable_data_.indicator_constraints.AddConstraints( model_proto.indicator_constraints()); return storage; @@ -147,42 +145,22 @@ void ModelStorage::UpdateLinearConstraintCoefficients( std::unique_ptr ModelStorage::Clone( const std::optional new_name) const { - ModelProto model_proto = ExportModel(); + // We leverage the private copy constructor that copies copyable_data_ but not + // update_trackers_ here. + std::unique_ptr clone = + absl::WrapUnique(new ModelStorage(*this)); if (new_name.has_value()) { - model_proto.set_name(std::string(*new_name)); + clone->copyable_data_.name = *new_name; } - absl::StatusOr> clone = - ModelStorage::FromModelProto(model_proto); - // Unless there is a very serious bug, a model exported by ExportModel() - // should always be valid. - CHECK_OK(clone.status()); - - // Update the next ids so that the clone does not reused any deleted id from - // the original. - clone.value()->ensure_next_variable_id_at_least(next_variable_id()); - clone.value()->ensure_next_auxiliary_objective_id_at_least( - next_auxiliary_objective_id()); - clone.value()->ensure_next_linear_constraint_id_at_least( - next_linear_constraint_id()); - clone.value()->ensure_next_constraint_id_at_least( - next_constraint_id()); - clone.value()->ensure_next_constraint_id_at_least( - next_constraint_id()); - clone.value()->ensure_next_constraint_id_at_least( - next_constraint_id()); - clone.value()->ensure_next_constraint_id_at_least( - next_constraint_id()); - clone.value()->ensure_next_constraint_id_at_least( - next_constraint_id()); - - return std::move(clone).value(); + return clone; } VariableId ModelStorage::AddVariable(const double lower_bound, const double upper_bound, const bool is_integer, const absl::string_view name) { - return variables_.Add(lower_bound, upper_bound, is_integer, name); + return copyable_data_.variables.Add(lower_bound, upper_bound, is_integer, + name); } void ModelStorage::AddVariables(const VariablesProto& variables) { @@ -200,39 +178,39 @@ void ModelStorage::AddVariables(const VariablesProto& variables) { } void ModelStorage::DeleteVariable(const VariableId id) { - CHECK(variables_.contains(id)); + CHECK(copyable_data_.variables.contains(id)); const auto& trackers = update_trackers_.GetUpdatedTrackers(); // Reuse output of GetUpdatedTrackers() only once to ensure a consistent view, // do not call UpdateAndGetLinearConstraintDiffs() etc. - objectives_.DeleteVariable( + copyable_data_.objectives.DeleteVariable( id, MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_objective>(trackers)); - linear_constraints_.DeleteVariable( + copyable_data_.linear_constraints.DeleteVariable( id, MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_linear_constraints>( trackers)); - quadratic_constraints_.DeleteVariable(id); - soc_constraints_.DeleteVariable(id); - sos1_constraints_.DeleteVariable(id); - sos2_constraints_.DeleteVariable(id); - indicator_constraints_.DeleteVariable(id); - variables_.Delete( + copyable_data_.quadratic_constraints.DeleteVariable(id); + copyable_data_.soc_constraints.DeleteVariable(id); + copyable_data_.sos1_constraints.DeleteVariable(id); + copyable_data_.sos2_constraints.DeleteVariable(id); + copyable_data_.indicator_constraints.DeleteVariable(id); + copyable_data_.variables.Delete( id, MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_variables>(trackers)); } std::vector ModelStorage::variables() const { - return variables_.Variables(); + return copyable_data_.variables.Variables(); } std::vector ModelStorage::SortedVariables() const { - return variables_.SortedVariables(); + return copyable_data_.variables.SortedVariables(); } LinearConstraintId ModelStorage::AddLinearConstraint( const double lower_bound, const double upper_bound, const absl::string_view name) { - return linear_constraints_.Add(lower_bound, upper_bound, name); + return copyable_data_.linear_constraints.Add(lower_bound, upper_bound, name); } void ModelStorage::AddLinearConstraints( @@ -253,16 +231,17 @@ void ModelStorage::AddLinearConstraints( } void ModelStorage::DeleteLinearConstraint(const LinearConstraintId id) { - CHECK(linear_constraints_.contains(id)); - linear_constraints_.Delete(id, UpdateAndGetLinearConstraintDiffs()); + CHECK(copyable_data_.linear_constraints.contains(id)); + copyable_data_.linear_constraints.Delete(id, + UpdateAndGetLinearConstraintDiffs()); } std::vector ModelStorage::LinearConstraints() const { - return linear_constraints_.LinearConstraints(); + return copyable_data_.linear_constraints.LinearConstraints(); } std::vector ModelStorage::SortedLinearConstraints() const { - return linear_constraints_.SortedLinearConstraints(); + return copyable_data_.linear_constraints.SortedLinearConstraints(); } void ModelStorage::AddAuxiliaryObjectives( @@ -285,23 +264,26 @@ void ModelStorage::AddAuxiliaryObjectives( // tries to create a very long RepeatedField. ModelProto ModelStorage::ExportModel(const bool remove_names) const { ModelProto result; - result.set_name(name_); - *result.mutable_variables() = variables_.Proto(); + result.set_name(copyable_data_.name); + *result.mutable_variables() = copyable_data_.variables.Proto(); { - auto [primary, auxiliary] = objectives_.Proto(); + auto [primary, auxiliary] = copyable_data_.objectives.Proto(); *result.mutable_objective() = std::move(primary); *result.mutable_auxiliary_objectives() = std::move(auxiliary); } { - auto [constraints, matrix] = linear_constraints_.Proto(); + auto [constraints, matrix] = copyable_data_.linear_constraints.Proto(); *result.mutable_linear_constraints() = std::move(constraints); *result.mutable_linear_constraint_matrix() = std::move(matrix); } - *result.mutable_quadratic_constraints() = quadratic_constraints_.Proto(); - *result.mutable_second_order_cone_constraints() = soc_constraints_.Proto(); - *result.mutable_sos1_constraints() = sos1_constraints_.Proto(); - *result.mutable_sos2_constraints() = sos2_constraints_.Proto(); - *result.mutable_indicator_constraints() = indicator_constraints_.Proto(); + *result.mutable_quadratic_constraints() = + copyable_data_.quadratic_constraints.Proto(); + *result.mutable_second_order_cone_constraints() = + copyable_data_.soc_constraints.Proto(); + *result.mutable_sos1_constraints() = copyable_data_.sos1_constraints.Proto(); + *result.mutable_sos2_constraints() = copyable_data_.sos2_constraints.Proto(); + *result.mutable_indicator_constraints() = + copyable_data_.indicator_constraints.Proto(); // Performance can be improved when remove_names is true by just not // extracting the names above instead of clearing them below, but this will // be more code, see discussion on cl/549469633 and prototype in cl/549369764. @@ -317,15 +299,19 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate( // We must detect the empty case to prevent unneeded copies and merging in // ExportModelUpdate(). - if (storage.variables_.diff_is_empty(dirty_variables) && - storage.objectives_.diff_is_empty(dirty_objective) && - storage.linear_constraints_.diff_is_empty(dirty_linear_constraints) && - storage.quadratic_constraints_.diff_is_empty( + if (storage.copyable_data_.variables.diff_is_empty(dirty_variables) && + storage.copyable_data_.objectives.diff_is_empty(dirty_objective) && + storage.copyable_data_.linear_constraints.diff_is_empty( + dirty_linear_constraints) && + storage.copyable_data_.quadratic_constraints.diff_is_empty( dirty_quadratic_constraints) && - storage.soc_constraints_.diff_is_empty(dirty_soc_constraints) && - storage.sos1_constraints_.diff_is_empty(dirty_sos1_constraints) && - storage.sos2_constraints_.diff_is_empty(dirty_sos2_constraints) && - storage.indicator_constraints_.diff_is_empty( + storage.copyable_data_.soc_constraints.diff_is_empty( + dirty_soc_constraints) && + storage.copyable_data_.sos1_constraints.diff_is_empty( + dirty_sos1_constraints) && + storage.copyable_data_.sos2_constraints.diff_is_empty( + dirty_sos2_constraints) && + storage.copyable_data_.indicator_constraints.diff_is_empty( dirty_indicator_constraints)) { return std::nullopt; } @@ -335,18 +321,19 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate( // Variable/constraint deletions. { VariableStorage::UpdateResult variable_update = - storage.variables_.Update(dirty_variables); + storage.copyable_data_.variables.Update(dirty_variables); *result.mutable_deleted_variable_ids() = std::move(variable_update.deleted); *result.mutable_variable_updates() = std::move(variable_update.updates); *result.mutable_new_variables() = std::move(variable_update.creates); } const std::vector new_variables = - storage.variables_.VariablesFrom(dirty_variables.checkpoint); + storage.copyable_data_.variables.VariablesFrom( + dirty_variables.checkpoint); // Linear constraint updates { LinearConstraintStorage::UpdateResult lin_con_update = - storage.linear_constraints_.Update( + storage.copyable_data_.linear_constraints.Update( dirty_linear_constraints, dirty_variables.deleted, new_variables); *result.mutable_deleted_linear_constraint_ids() = std::move(lin_con_update.deleted); @@ -360,25 +347,27 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate( // Quadratic constraint updates *result.mutable_quadratic_constraint_updates() = - storage.quadratic_constraints_.Update(dirty_quadratic_constraints); + storage.copyable_data_.quadratic_constraints.Update( + dirty_quadratic_constraints); // Second-order cone constraint updates *result.mutable_second_order_cone_constraint_updates() = - storage.soc_constraints_.Update(dirty_soc_constraints); + storage.copyable_data_.soc_constraints.Update(dirty_soc_constraints); // SOS constraint updates *result.mutable_sos1_constraint_updates() = - storage.sos1_constraints_.Update(dirty_sos1_constraints); + storage.copyable_data_.sos1_constraints.Update(dirty_sos1_constraints); *result.mutable_sos2_constraint_updates() = - storage.sos2_constraints_.Update(dirty_sos2_constraints); + storage.copyable_data_.sos2_constraints.Update(dirty_sos2_constraints); // Indicator constraint updates *result.mutable_indicator_constraint_updates() = - storage.indicator_constraints_.Update(dirty_indicator_constraints); + storage.copyable_data_.indicator_constraints.Update( + dirty_indicator_constraints); // Update the objective { - auto [primary, auxiliary] = storage.objectives_.Update( + auto [primary, auxiliary] = storage.copyable_data_.objectives.Update( dirty_objective, dirty_variables.deleted, new_variables); *result.mutable_objective_updates() = std::move(primary); *result.mutable_auxiliary_objectives_updates() = std::move(auxiliary); @@ -392,25 +381,29 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate( void ModelStorage::UpdateTrackerData::AdvanceCheckpoint( const ModelStorage& storage) { - storage.variables_.AdvanceCheckpointInDiff(dirty_variables); - storage.objectives_.AdvanceCheckpointInDiff(dirty_variables.checkpoint, - dirty_objective); - storage.linear_constraints_.AdvanceCheckpointInDiff( + storage.copyable_data_.variables.AdvanceCheckpointInDiff(dirty_variables); + storage.copyable_data_.objectives.AdvanceCheckpointInDiff( + dirty_variables.checkpoint, dirty_objective); + storage.copyable_data_.linear_constraints.AdvanceCheckpointInDiff( dirty_variables.checkpoint, dirty_linear_constraints); - storage.quadratic_constraints_.AdvanceCheckpointInDiff( + storage.copyable_data_.quadratic_constraints.AdvanceCheckpointInDiff( dirty_quadratic_constraints); - storage.soc_constraints_.AdvanceCheckpointInDiff(dirty_soc_constraints); - storage.sos1_constraints_.AdvanceCheckpointInDiff(dirty_sos1_constraints); - storage.sos2_constraints_.AdvanceCheckpointInDiff(dirty_sos2_constraints); - storage.indicator_constraints_.AdvanceCheckpointInDiff( + storage.copyable_data_.soc_constraints.AdvanceCheckpointInDiff( + dirty_soc_constraints); + storage.copyable_data_.sos1_constraints.AdvanceCheckpointInDiff( + dirty_sos1_constraints); + storage.copyable_data_.sos2_constraints.AdvanceCheckpointInDiff( + dirty_sos2_constraints); + storage.copyable_data_.indicator_constraints.AdvanceCheckpointInDiff( dirty_indicator_constraints); } UpdateTrackerId ModelStorage::NewUpdateTracker() { return update_trackers_.NewUpdateTracker( - variables_, objectives_, linear_constraints_, quadratic_constraints_, - soc_constraints_, sos1_constraints_, sos2_constraints_, - indicator_constraints_); + copyable_data_.variables, copyable_data_.objectives, + copyable_data_.linear_constraints, copyable_data_.quadratic_constraints, + copyable_data_.soc_constraints, copyable_data_.sos1_constraints, + copyable_data_.sos2_constraints, copyable_data_.indicator_constraints); } void ModelStorage::DeleteUpdateTracker(const UpdateTrackerId update_tracker) { @@ -438,50 +431,50 @@ absl::Status ModelStorage::ApplyUpdateProto( RETURN_IF_ERROR(summary.variables.Insert(id.value(), variable_name(id))) << "invalid variable id in model"; } - RETURN_IF_ERROR( - summary.variables.SetNextFreeId(variables_.next_id().value())); + RETURN_IF_ERROR(summary.variables.SetNextFreeId( + copyable_data_.variables.next_id().value())); for (const AuxiliaryObjectiveId id : SortedAuxiliaryObjectives()) { RETURN_IF_ERROR( summary.auxiliary_objectives.Insert(id.value(), objective_name(id))) << "invalid auxiliary objective id in model"; } RETURN_IF_ERROR(summary.auxiliary_objectives.SetNextFreeId( - objectives_.next_id().value())); + copyable_data_.objectives.next_id().value())); for (const LinearConstraintId id : SortedLinearConstraints()) { RETURN_IF_ERROR(summary.linear_constraints.Insert( id.value(), linear_constraint_name(id))) << "invalid linear constraint id in model"; } RETURN_IF_ERROR(summary.linear_constraints.SetNextFreeId( - linear_constraints_.next_id().value())); + copyable_data_.linear_constraints.next_id().value())); for (const auto id : SortedConstraints()) { RETURN_IF_ERROR(summary.quadratic_constraints.Insert( - id.value(), quadratic_constraints_.data(id).name)) + id.value(), copyable_data_.quadratic_constraints.data(id).name)) << "invalid quadratic constraint id in model"; } RETURN_IF_ERROR(summary.quadratic_constraints.SetNextFreeId( - quadratic_constraints_.next_id().value())); + copyable_data_.quadratic_constraints.next_id().value())); for (const auto id : SortedConstraints()) { RETURN_IF_ERROR(summary.second_order_cone_constraints.Insert( - id.value(), soc_constraints_.data(id).name)) + id.value(), copyable_data_.soc_constraints.data(id).name)) << "invalid second-order cone constraint id in model"; } RETURN_IF_ERROR(summary.second_order_cone_constraints.SetNextFreeId( - soc_constraints_.next_id().value())); + copyable_data_.soc_constraints.next_id().value())); for (const Sos1ConstraintId id : SortedConstraints()) { RETURN_IF_ERROR(summary.sos1_constraints.Insert( id.value(), constraint_data(id).name())) << "invalid SOS1 constraint id in model"; } RETURN_IF_ERROR(summary.sos1_constraints.SetNextFreeId( - sos1_constraints_.next_id().value())); + copyable_data_.sos1_constraints.next_id().value())); for (const Sos2ConstraintId id : SortedConstraints()) { RETURN_IF_ERROR(summary.sos2_constraints.Insert( id.value(), constraint_data(id).name())) << "invalid SOS2 constraint id in model"; } RETURN_IF_ERROR(summary.sos2_constraints.SetNextFreeId( - sos2_constraints_.next_id().value())); + copyable_data_.sos2_constraints.next_id().value())); for (const IndicatorConstraintId id : SortedConstraints()) { @@ -489,7 +482,7 @@ absl::Status ModelStorage::ApplyUpdateProto( id.value(), constraint_data(id).name)); } RETURN_IF_ERROR(summary.indicator_constraints.SetNextFreeId( - indicator_constraints_.next_id().value())); + copyable_data_.indicator_constraints.next_id().value())); RETURN_IF_ERROR(ValidateModelUpdate(update_proto, summary)) << "update not valid"; @@ -556,15 +549,15 @@ absl::Status ModelStorage::ApplyUpdateProto( AddAuxiliaryObjectives( update_proto.auxiliary_objectives_updates().new_objectives()); AddLinearConstraints(update_proto.new_linear_constraints()); - quadratic_constraints_.AddConstraints( + copyable_data_.quadratic_constraints.AddConstraints( update_proto.quadratic_constraint_updates().new_constraints()); - soc_constraints_.AddConstraints( + copyable_data_.soc_constraints.AddConstraints( update_proto.second_order_cone_constraint_updates().new_constraints()); - sos1_constraints_.AddConstraints( + copyable_data_.sos1_constraints.AddConstraints( update_proto.sos1_constraint_updates().new_constraints()); - sos2_constraints_.AddConstraints( + copyable_data_.sos2_constraints.AddConstraints( update_proto.sos2_constraint_updates().new_constraints()); - indicator_constraints_.AddConstraints( + copyable_data_.indicator_constraints.AddConstraints( update_proto.indicator_constraint_updates().new_constraints()); // Update the primary objective. diff --git a/ortools/math_opt/storage/model_storage.h b/ortools/math_opt/storage/model_storage.h index 372aff7f2c..0c5fc98718 100644 --- a/ortools/math_opt/storage/model_storage.h +++ b/ortools/math_opt/storage/model_storage.h @@ -171,7 +171,6 @@ class ModelStorage { inline explicit ModelStorage(absl::string_view model_name = "", absl::string_view primary_objective_name = ""); - ModelStorage(const ModelStorage&) = delete; ModelStorage& operator=(const ModelStorage&) = delete; // Returns a clone of the model, optionally changing model's name. @@ -183,7 +182,7 @@ class ModelStorage { std::unique_ptr Clone( std::optional new_name = std::nullopt) const; - inline const std::string& name() const { return name_; } + inline const std::string& name() const { return copyable_data_.name; } ////////////////////////////////////////////////////////////////////////////// // Variables @@ -692,6 +691,30 @@ class ModelStorage { dirty_indicator_constraints; }; + // All data that is copied (by the C++ default copy constructor) when using + // Clone(). + struct CopyableData { + CopyableData(const absl::string_view model_name, + const absl::string_view primary_objective_name) + : name(model_name), objectives(/*name=*/primary_objective_name) {} + std::string name; + + VariableStorage variables; + ObjectiveStorage objectives; + LinearConstraintStorage linear_constraints; + + AtomicConstraintStorage quadratic_constraints; + AtomicConstraintStorage soc_constraints; + AtomicConstraintStorage sos1_constraints; + AtomicConstraintStorage sos2_constraints; + AtomicConstraintStorage indicator_constraints; + }; + + // Private copy constructor that copies only copyable_data_, not + // update_trackers_. It is used internally by Clone(). + ModelStorage(const ModelStorage& other) + : copyable_data_(other.copyable_data_) {} + auto UpdateAndGetVariableDiffs() { return MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_variables>( update_trackers_.GetUpdatedTrackers()); @@ -745,18 +768,7 @@ class ModelStorage { template const AtomicConstraintStorage& constraint_storage() const; - std::string name_; - - VariableStorage variables_; - ObjectiveStorage objectives_; - LinearConstraintStorage linear_constraints_; - - AtomicConstraintStorage quadratic_constraints_; - AtomicConstraintStorage soc_constraints_; - AtomicConstraintStorage sos1_constraints_; - AtomicConstraintStorage sos2_constraints_; - AtomicConstraintStorage indicator_constraints_; - + CopyableData copyable_data_; UpdateTrackers update_trackers_; }; @@ -768,7 +780,8 @@ class ModelStorage { ModelStorage::ModelStorage(const absl::string_view model_name, const absl::string_view primary_objective_name) - : name_(model_name), objectives_(primary_objective_name) {} + : copyable_data_(/*model_name=*/model_name, + /*primary_objective_name=*/primary_objective_name) {} //////////////////////////////////////////////////////////////////////////////// // Variables @@ -780,34 +793,37 @@ VariableId ModelStorage::AddVariable(absl::string_view name) { } double ModelStorage::variable_lower_bound(const VariableId id) const { - return variables_.lower_bound(id); + return copyable_data_.variables.lower_bound(id); } double ModelStorage::variable_upper_bound(const VariableId id) const { - return variables_.upper_bound(id); + return copyable_data_.variables.upper_bound(id); } bool ModelStorage::is_variable_integer(VariableId id) const { - return variables_.is_integer(id); + return copyable_data_.variables.is_integer(id); } const std::string& ModelStorage::variable_name(const VariableId id) const { - return variables_.name(id); + return copyable_data_.variables.name(id); } void ModelStorage::set_variable_lower_bound(const VariableId id, const double lower_bound) { - variables_.set_lower_bound(id, lower_bound, UpdateAndGetVariableDiffs()); + copyable_data_.variables.set_lower_bound(id, lower_bound, + UpdateAndGetVariableDiffs()); } void ModelStorage::set_variable_upper_bound(const VariableId id, const double upper_bound) { - variables_.set_upper_bound(id, upper_bound, UpdateAndGetVariableDiffs()); + copyable_data_.variables.set_upper_bound(id, upper_bound, + UpdateAndGetVariableDiffs()); } void ModelStorage::set_variable_is_integer(const VariableId id, const bool is_integer) { - variables_.set_integer(id, is_integer, UpdateAndGetVariableDiffs()); + copyable_data_.variables.set_integer(id, is_integer, + UpdateAndGetVariableDiffs()); } void ModelStorage::set_variable_as_integer(VariableId id) { @@ -818,18 +834,20 @@ void ModelStorage::set_variable_as_continuous(VariableId id) { set_variable_is_integer(id, false); } -int ModelStorage::num_variables() const { return variables_.size(); } +int ModelStorage::num_variables() const { + return static_cast(copyable_data_.variables.size()); +} VariableId ModelStorage::next_variable_id() const { - return variables_.next_id(); + return copyable_data_.variables.next_id(); } void ModelStorage::ensure_next_variable_id_at_least(const VariableId id) { - variables_.ensure_next_id_at_least(id); + copyable_data_.variables.ensure_next_id_at_least(id); } bool ModelStorage::has_variable(const VariableId id) const { - return variables_.contains(id); + return copyable_data_.variables.contains(id); } //////////////////////////////////////////////////////////////////////////////// @@ -843,46 +861,46 @@ LinearConstraintId ModelStorage::AddLinearConstraint(absl::string_view name) { double ModelStorage::linear_constraint_lower_bound( const LinearConstraintId id) const { - return linear_constraints_.lower_bound(id); + return copyable_data_.linear_constraints.lower_bound(id); } double ModelStorage::linear_constraint_upper_bound( const LinearConstraintId id) const { - return linear_constraints_.upper_bound(id); + return copyable_data_.linear_constraints.upper_bound(id); } const std::string& ModelStorage::linear_constraint_name( const LinearConstraintId id) const { - return linear_constraints_.name(id); + return copyable_data_.linear_constraints.name(id); } void ModelStorage::set_linear_constraint_lower_bound( const LinearConstraintId id, const double lower_bound) { - linear_constraints_.set_lower_bound(id, lower_bound, - UpdateAndGetLinearConstraintDiffs()); + copyable_data_.linear_constraints.set_lower_bound( + id, lower_bound, UpdateAndGetLinearConstraintDiffs()); } void ModelStorage::set_linear_constraint_upper_bound( const LinearConstraintId id, const double upper_bound) { - linear_constraints_.set_upper_bound(id, upper_bound, - UpdateAndGetLinearConstraintDiffs()); + copyable_data_.linear_constraints.set_upper_bound( + id, upper_bound, UpdateAndGetLinearConstraintDiffs()); } int ModelStorage::num_linear_constraints() const { - return linear_constraints_.size(); + return static_cast(copyable_data_.linear_constraints.size()); } LinearConstraintId ModelStorage::next_linear_constraint_id() const { - return linear_constraints_.next_id(); + return copyable_data_.linear_constraints.next_id(); } void ModelStorage::ensure_next_linear_constraint_id_at_least( LinearConstraintId id) { - linear_constraints_.ensure_next_id_at_least(id); + copyable_data_.linear_constraints.ensure_next_id_at_least(id); } bool ModelStorage::has_linear_constraint(const LinearConstraintId id) const { - return linear_constraints_.contains(id); + return copyable_data_.linear_constraints.contains(id); } //////////////////////////////////////////////////////////////////////////////// @@ -891,34 +909,35 @@ bool ModelStorage::has_linear_constraint(const LinearConstraintId id) const { double ModelStorage::linear_constraint_coefficient( LinearConstraintId constraint, VariableId variable) const { - return linear_constraints_.matrix().get(constraint, variable); + return copyable_data_.linear_constraints.matrix().get(constraint, variable); } bool ModelStorage::is_linear_constraint_coefficient_nonzero( LinearConstraintId constraint, VariableId variable) const { - return linear_constraints_.matrix().contains(constraint, variable); + return copyable_data_.linear_constraints.matrix().contains(constraint, + variable); } void ModelStorage::set_linear_constraint_coefficient( const LinearConstraintId constraint, const VariableId variable, const double value) { - linear_constraints_.set_term(constraint, variable, value, - UpdateAndGetLinearConstraintDiffs()); + copyable_data_.linear_constraints.set_term( + constraint, variable, value, UpdateAndGetLinearConstraintDiffs()); } -std::vector> +std::vector > ModelStorage::linear_constraint_matrix() const { - return linear_constraints_.matrix().Terms(); + return copyable_data_.linear_constraints.matrix().Terms(); } std::vector ModelStorage::variables_in_linear_constraint( LinearConstraintId constraint) const { - return linear_constraints_.matrix().row(constraint); + return copyable_data_.linear_constraints.matrix().row(constraint); } std::vector ModelStorage::linear_constraints_with_variable( VariableId variable) const { - return linear_constraints_.matrix().column(variable); + return copyable_data_.linear_constraints.matrix().column(variable); } //////////////////////////////////////////////////////////////////////////////// @@ -926,47 +945,49 @@ std::vector ModelStorage::linear_constraints_with_variable( //////////////////////////////////////////////////////////////////////////////// bool ModelStorage::is_maximize(const ObjectiveId id) const { - return objectives_.maximize(id); + return copyable_data_.objectives.maximize(id); } int64_t ModelStorage::objective_priority(const ObjectiveId id) const { - return objectives_.priority(id); + return copyable_data_.objectives.priority(id); } double ModelStorage::objective_offset(const ObjectiveId id) const { - return objectives_.offset(id); + return copyable_data_.objectives.offset(id); } double ModelStorage::linear_objective_coefficient( const ObjectiveId id, const VariableId variable) const { - return objectives_.linear_term(id, variable); + return copyable_data_.objectives.linear_term(id, variable); } double ModelStorage::quadratic_objective_coefficient( const ObjectiveId id, const VariableId first_variable, const VariableId second_variable) const { - return objectives_.quadratic_term(id, first_variable, second_variable); + return copyable_data_.objectives.quadratic_term(id, first_variable, + second_variable); } bool ModelStorage::is_linear_objective_coefficient_nonzero( const ObjectiveId id, const VariableId variable) const { - return objectives_.linear_terms(id).contains(variable); + return copyable_data_.objectives.linear_terms(id).contains(variable); } bool ModelStorage::is_quadratic_objective_coefficient_nonzero( const ObjectiveId id, const VariableId first_variable, const VariableId second_variable) const { - return objectives_.quadratic_terms(id).get(first_variable, second_variable) != - 0.0; + return copyable_data_.objectives.quadratic_terms(id).get( + first_variable, second_variable) != 0.0; } const std::string& ModelStorage::objective_name(const ObjectiveId id) const { - return objectives_.name(id); + return copyable_data_.objectives.name(id); } void ModelStorage::set_is_maximize(const ObjectiveId id, const bool is_maximize) { - objectives_.set_maximize(id, is_maximize, UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.set_maximize(id, is_maximize, + UpdateAndGetObjectiveDiffs()); } void ModelStorage::set_maximize(const ObjectiveId id) { @@ -979,49 +1000,50 @@ void ModelStorage::set_minimize(const ObjectiveId id) { void ModelStorage::set_objective_priority(const ObjectiveId id, const int64_t value) { - objectives_.set_priority(id, value, UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.set_priority(id, value, + UpdateAndGetObjectiveDiffs()); } void ModelStorage::set_objective_offset(const ObjectiveId id, const double value) { - objectives_.set_offset(id, value, UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.set_offset(id, value, UpdateAndGetObjectiveDiffs()); } void ModelStorage::set_linear_objective_coefficient(const ObjectiveId id, const VariableId variable, const double value) { - objectives_.set_linear_term(id, variable, value, - UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.set_linear_term(id, variable, value, + UpdateAndGetObjectiveDiffs()); } void ModelStorage::set_quadratic_objective_coefficient( const ObjectiveId id, const VariableId first_variable, const VariableId second_variable, const double value) { - objectives_.set_quadratic_term(id, first_variable, second_variable, value, - UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.set_quadratic_term( + id, first_variable, second_variable, value, UpdateAndGetObjectiveDiffs()); } void ModelStorage::clear_objective(const ObjectiveId id) { - objectives_.Clear(id, UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.Clear(id, UpdateAndGetObjectiveDiffs()); } const absl::flat_hash_map& ModelStorage::linear_objective( const ObjectiveId id) const { - return objectives_.linear_terms(id); + return copyable_data_.objectives.linear_terms(id); } int64_t ModelStorage::num_linear_objective_terms(const ObjectiveId id) const { - return objectives_.linear_terms(id).size(); + return copyable_data_.objectives.linear_terms(id).size(); } int64_t ModelStorage::num_quadratic_objective_terms( const ObjectiveId id) const { - return objectives_.quadratic_terms(id).nonzeros(); + return copyable_data_.objectives.quadratic_terms(id).nonzeros(); } -std::vector> +std::vector > ModelStorage::quadratic_objective_terms(const ObjectiveId id) const { - return objectives_.quadratic_terms(id).Terms(); + return copyable_data_.objectives.quadratic_terms(id).Terms(); } //////////////////////////////////////////////////////////////////////////////// @@ -1030,38 +1052,38 @@ ModelStorage::quadratic_objective_terms(const ObjectiveId id) const { AuxiliaryObjectiveId ModelStorage::AddAuxiliaryObjective( const int64_t priority, const absl::string_view name) { - return objectives_.AddAuxiliaryObjective(priority, name); + return copyable_data_.objectives.AddAuxiliaryObjective(priority, name); } void ModelStorage::DeleteAuxiliaryObjective(const AuxiliaryObjectiveId id) { - objectives_.Delete(id, UpdateAndGetObjectiveDiffs()); + copyable_data_.objectives.Delete(id, UpdateAndGetObjectiveDiffs()); } int ModelStorage::num_auxiliary_objectives() const { - return static_cast(objectives_.num_auxiliary_objectives()); + return static_cast(copyable_data_.objectives.num_auxiliary_objectives()); } AuxiliaryObjectiveId ModelStorage::next_auxiliary_objective_id() const { - return objectives_.next_id(); + return copyable_data_.objectives.next_id(); } void ModelStorage::ensure_next_auxiliary_objective_id_at_least( const AuxiliaryObjectiveId id) { - objectives_.ensure_next_id_at_least(id); + copyable_data_.objectives.ensure_next_id_at_least(id); } bool ModelStorage::has_auxiliary_objective( const AuxiliaryObjectiveId id) const { - return objectives_.contains(id); + return copyable_data_.objectives.contains(id); } std::vector ModelStorage::AuxiliaryObjectives() const { - return objectives_.AuxiliaryObjectives(); + return copyable_data_.objectives.AuxiliaryObjectives(); } std::vector ModelStorage::SortedAuxiliaryObjectives() const { - return objectives_.SortedAuxiliaryObjectives(); + return copyable_data_.objectives.SortedAuxiliaryObjectives(); } //////////////////////////////////////////////////////////////////////////////// @@ -1162,13 +1184,13 @@ std::vector ModelStorage::VariablesInConstraint( template <> inline AtomicConstraintStorage& ModelStorage::constraint_storage() { - return quadratic_constraints_; + return copyable_data_.quadratic_constraints; } template <> inline const AtomicConstraintStorage& ModelStorage::constraint_storage() const { - return quadratic_constraints_; + return copyable_data_.quadratic_constraints; } template <> @@ -1184,13 +1206,13 @@ constexpr typename AtomicConstraintStorage::Diff template <> inline AtomicConstraintStorage& ModelStorage::constraint_storage() { - return soc_constraints_; + return copyable_data_.soc_constraints; } template <> inline const AtomicConstraintStorage& ModelStorage::constraint_storage() const { - return soc_constraints_; + return copyable_data_.soc_constraints; } template <> @@ -1206,13 +1228,13 @@ constexpr typename AtomicConstraintStorage::Diff template <> inline AtomicConstraintStorage& ModelStorage::constraint_storage() { - return sos1_constraints_; + return copyable_data_.sos1_constraints; } template <> inline const AtomicConstraintStorage& ModelStorage::constraint_storage() const { - return sos1_constraints_; + return copyable_data_.sos1_constraints; } template <> @@ -1228,13 +1250,13 @@ constexpr typename AtomicConstraintStorage::Diff template <> inline AtomicConstraintStorage& ModelStorage::constraint_storage() { - return sos2_constraints_; + return copyable_data_.sos2_constraints; } template <> inline const AtomicConstraintStorage& ModelStorage::constraint_storage() const { - return sos2_constraints_; + return copyable_data_.sos2_constraints; } template <> @@ -1250,13 +1272,13 @@ constexpr typename AtomicConstraintStorage::Diff template <> inline AtomicConstraintStorage& ModelStorage::constraint_storage() { - return indicator_constraints_; + return copyable_data_.indicator_constraints; } template <> inline const AtomicConstraintStorage& ModelStorage::constraint_storage() const { - return indicator_constraints_; + return copyable_data_.indicator_constraints; } template <>