math_opt: Export from google3

This commit is contained in:
Corentin Le Molgat
2024-11-12 13:55:16 +01:00
parent 8a5976b99f
commit cac5698cd2
17 changed files with 926 additions and 291 deletions

View File

@@ -11,8 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
load("@rules_python//python:proto.bzl", "py_proto_library")
package(default_visibility = ["//visibility:public"])

View File

@@ -28,6 +28,7 @@ py_library(
":errors",
":expressions",
":hash_model_storage",
":init_arguments",
":message_callback",
":model",
":model_parameters",
@@ -160,6 +161,7 @@ py_library(
":callback",
":compute_infeasible_subsystem_result",
":errors",
":init_arguments",
":message_callback",
":model",
":model_parameters",
@@ -210,3 +212,12 @@ py_library(
srcs = ["errors.py"],
deps = ["//ortools/math_opt:rpc_py_pb2"],
)
py_library(
name = "init_arguments",
srcs = ["init_arguments.py"],
deps = [
"//ortools/math_opt:parameters_py_pb2",
"//ortools/math_opt/solvers:gurobi_py_pb2",
],
)

View File

@@ -0,0 +1,180 @@
# Copyright 2010-2024 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configures the instantiation of the underlying solver."""
import dataclasses
from typing import Optional
from ortools.math_opt import parameters_pb2
from ortools.math_opt.solvers import gurobi_pb2
@dataclasses.dataclass
class StreamableGScipInitArguments:
"""Streamable GScip specific parameters for solver instantiation."""
@dataclasses.dataclass(frozen=True)
class GurobiISVKey:
"""The Gurobi ISV key, an alternative to license files.
Contact Gurobi for details.
Attributes:
name: A string, typically a company/organization.
application_name: A string, typically a project.
expiration: An int, a value of 0 indicates no expiration.
key: A string, the secret.
"""
name: str = ""
application_name: str = ""
expiration: int = 0
key: str = ""
def to_proto(self) -> gurobi_pb2.GurobiInitializerProto.ISVKey:
"""Returns a protocol buffer equivalent of this."""
return gurobi_pb2.GurobiInitializerProto.ISVKey(
name=self.name,
application_name=self.application_name,
expiration=self.expiration,
key=self.key,
)
def gurobi_isv_key_from_proto(
proto: gurobi_pb2.GurobiInitializerProto.ISVKey,
) -> GurobiISVKey:
"""Returns an equivalent GurobiISVKey to the input proto."""
return GurobiISVKey(
name=proto.name,
application_name=proto.application_name,
expiration=proto.expiration,
key=proto.key,
)
@dataclasses.dataclass
class StreamableGurobiInitArguments:
"""Streamable Gurobi specific parameters for solver instantiation."""
isv_key: Optional[GurobiISVKey] = None
def to_proto(self) -> gurobi_pb2.GurobiInitializerProto:
"""Returns a protocol buffer equivalent of this."""
return gurobi_pb2.GurobiInitializerProto(
isv_key=self.isv_key.to_proto() if self.isv_key else None
)
def streamable_gurobi_init_arguments_from_proto(
proto: gurobi_pb2.GurobiInitializerProto,
) -> StreamableGurobiInitArguments:
"""Returns an equivalent StreamableGurobiInitArguments to the input proto."""
result = StreamableGurobiInitArguments()
if proto.HasField("isv_key"):
result.isv_key = gurobi_isv_key_from_proto(proto.isv_key)
return result
@dataclasses.dataclass
class StreamableGlopInitArguments:
"""Streamable Glop specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableCpSatInitArguments:
"""Streamable CP-SAT specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamablePdlpInitArguments:
"""Streamable Pdlp specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableGlpkInitArguments:
"""Streamable GLPK specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableOsqpInitArguments:
"""Streamable OSQP specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableEcosInitArguments:
"""Streamable Ecos specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableScsInitArguments:
"""Streamable Scs specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableHighsInitArguments:
"""Streamable Highs specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableSantoriniInitArguments:
"""Streamable Santorini specific parameters for solver instantiation."""
@dataclasses.dataclass
class StreamableSolverInitArguments:
"""Solver initialization parameters that can be sent to another process.
Attributes:
gscip: Initialization parameters specific to GScip.
gurobi: Initialization parameters specific to Gurobi.
glop: Initialization parameters specific to GLOP.
cp_sat: Initialization parameters specific to CP-SAT.
pdlp: Initialization parameters specific to PDLP.
glpk: Initialization parameters specific to GLPK.
osqp: Initialization parameters specific to OSQP.
ecos: Initialization parameters specific to ECOS.
scs: Initialization parameters specific to SCS.
highs: Initialization parameters specific to HiGHS.
santorini: Initialization parameters specific to Santorini.
"""
gscip: Optional[StreamableGScipInitArguments] = None
gurobi: Optional[StreamableGurobiInitArguments] = None
glop: Optional[StreamableGlopInitArguments] = None
cp_sat: Optional[StreamableCpSatInitArguments] = None
pdlp: Optional[StreamablePdlpInitArguments] = None
glpk: Optional[StreamableGlpkInitArguments] = None
osqp: Optional[StreamableOsqpInitArguments] = None
ecos: Optional[StreamableEcosInitArguments] = None
scs: Optional[StreamableScsInitArguments] = None
highs: Optional[StreamableHighsInitArguments] = None
santorini: Optional[StreamableSantoriniInitArguments] = None
def to_proto(self) -> parameters_pb2.SolverInitializerProto:
"""Returns a protocol buffer equivalent of this."""
return parameters_pb2.SolverInitializerProto(
gurobi=self.gurobi.to_proto() if self.gurobi else None
)
def streamable_solver_init_arguments_from_proto(
proto: parameters_pb2.SolverInitializerProto,
) -> StreamableSolverInitArguments:
"""Returns an equivalent StreamableSolverInitArguments to the input proto."""
result = StreamableSolverInitArguments()
if proto.HasField("gurobi"):
result.gurobi = streamable_gurobi_init_arguments_from_proto(proto.gurobi)
return result

View File

@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright 2010-2024 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from ortools.math_opt import parameters_pb2
from ortools.math_opt.python import init_arguments
from ortools.math_opt.python.testing import compare_proto
from ortools.math_opt.solvers import gurobi_pb2
class GurobiISVKeyTest(absltest.TestCase, compare_proto.MathOptProtoAssertions):
def test_proto_conversions(self) -> None:
isv = init_arguments.GurobiISVKey(
name="cat", application_name="hat", expiration=4, key="bat"
)
proto_isv = gurobi_pb2.GurobiInitializerProto.ISVKey(
name="cat", application_name="hat", expiration=4, key="bat"
)
self.assert_protos_equiv(isv.to_proto(), proto_isv)
self.assertEqual(init_arguments.gurobi_isv_key_from_proto(proto_isv), isv)
class StreamableGurobiInitArgumentsTest(
absltest.TestCase, compare_proto.MathOptProtoAssertions
):
def test_proto_conversions_isv_key_set(self) -> None:
init = init_arguments.StreamableGurobiInitArguments(
isv_key=init_arguments.GurobiISVKey(
name="cat", application_name="hat", expiration=4, key="bat"
)
)
proto_init = gurobi_pb2.GurobiInitializerProto(
isv_key=gurobi_pb2.GurobiInitializerProto.ISVKey(
name="cat", application_name="hat", expiration=4, key="bat"
)
)
self.assert_protos_equiv(init.to_proto(), proto_init)
self.assertEqual(
init_arguments.streamable_gurobi_init_arguments_from_proto(proto_init),
init,
)
def test_proto_conversions_isv_key_not_set(self) -> None:
init = init_arguments.StreamableGurobiInitArguments()
proto_init = gurobi_pb2.GurobiInitializerProto()
self.assert_protos_equiv(init.to_proto(), proto_init)
self.assertEqual(
init_arguments.streamable_gurobi_init_arguments_from_proto(proto_init),
init,
)
class StreamableSolverInitArgumentsTest(
absltest.TestCase, compare_proto.MathOptProtoAssertions
):
def test_proto_conversions_gurobi_set(self) -> None:
init = init_arguments.StreamableSolverInitArguments(
gurobi=init_arguments.StreamableGurobiInitArguments(
isv_key=init_arguments.GurobiISVKey(
name="cat", application_name="hat", expiration=4, key="bat"
)
)
)
proto_init = parameters_pb2.SolverInitializerProto(
gurobi=gurobi_pb2.GurobiInitializerProto(
isv_key=gurobi_pb2.GurobiInitializerProto.ISVKey(
name="cat", application_name="hat", expiration=4, key="bat"
)
)
)
self.assert_protos_equiv(init.to_proto(), proto_init)
self.assertEqual(
init_arguments.streamable_solver_init_arguments_from_proto(proto_init),
init,
)
def test_proto_conversions_gurobi_not_set(self) -> None:
init = init_arguments.StreamableSolverInitArguments()
proto_init = parameters_pb2.SolverInitializerProto()
self.assert_protos_equiv(init.to_proto(), proto_init)
self.assertEqual(
init_arguments.streamable_solver_init_arguments_from_proto(proto_init),
init,
)
if __name__ == "__main__":
absltest.main()

View File

@@ -65,6 +65,26 @@ from ortools.math_opt.python.errors import status_proto_to_exception
from ortools.math_opt.python.expressions import evaluate_expression
from ortools.math_opt.python.expressions import fast_sum
from ortools.math_opt.python.hash_model_storage import HashModelStorage
from ortools.math_opt.python.init_arguments import gurobi_isv_key_from_proto
from ortools.math_opt.python.init_arguments import GurobiISVKey
from ortools.math_opt.python.init_arguments import (
streamable_gurobi_init_arguments_from_proto,
)
from ortools.math_opt.python.init_arguments import (
streamable_solver_init_arguments_from_proto,
)
from ortools.math_opt.python.init_arguments import StreamableCpSatInitArguments
from ortools.math_opt.python.init_arguments import StreamableEcosInitArguments
from ortools.math_opt.python.init_arguments import StreamableGlopInitArguments
from ortools.math_opt.python.init_arguments import StreamableGlpkInitArguments
from ortools.math_opt.python.init_arguments import StreamableGScipInitArguments
from ortools.math_opt.python.init_arguments import StreamableGurobiInitArguments
from ortools.math_opt.python.init_arguments import StreamableHighsInitArguments
from ortools.math_opt.python.init_arguments import StreamableOsqpInitArguments
from ortools.math_opt.python.init_arguments import StreamablePdlpInitArguments
from ortools.math_opt.python.init_arguments import StreamableSantoriniInitArguments
from ortools.math_opt.python.init_arguments import StreamableScsInitArguments
from ortools.math_opt.python.init_arguments import StreamableSolverInitArguments
from ortools.math_opt.python.message_callback import list_message_callback
from ortools.math_opt.python.message_callback import log_messages
from ortools.math_opt.python.message_callback import printer_message_callback

View File

@@ -22,6 +22,7 @@ from absl.testing import absltest
from ortools.math_opt.python import callback
from ortools.math_opt.python import expressions
from ortools.math_opt.python import hash_model_storage
from ortools.math_opt.python import init_arguments
from ortools.math_opt.python import mathopt
from ortools.math_opt.python import message_callback
from ortools.math_opt.python import model
@@ -48,6 +49,7 @@ _MODULES_TO_CHECK: List[types.ModuleType] = [
callback,
expressions,
hash_model_storage,
init_arguments,
message_callback,
model,
model_parameters,

View File

@@ -325,6 +325,18 @@ class Termination:
problem_status: ProblemStatus = ProblemStatus()
objective_bounds: ObjectiveBounds = ObjectiveBounds()
def to_proto(self) -> result_pb2.TerminationProto:
"""Returns an equivalent protocol buffer to this Termination."""
return result_pb2.TerminationProto(
reason=self.reason.value,
limit=(
result_pb2.LIMIT_UNSPECIFIED if self.limit is None else self.limit.value
),
detail=self.detail,
problem_status=self.problem_status.to_proto(),
objective_bounds=self.objective_bounds.to_proto(),
)
def parse_termination(
termination_proto: result_pb2.TerminationProto,
@@ -930,6 +942,39 @@ class SolveResult:
f"variable_status: {type(variables).__name__!r}"
)
def to_proto(self) -> result_pb2.SolveResultProto:
"""Returns an equivalent protocol buffer for a SolveResult."""
proto = result_pb2.SolveResultProto(
termination=self.termination.to_proto(),
solutions=[s.to_proto() for s in self.solutions],
primal_rays=[r.to_proto() for r in self.primal_rays],
dual_rays=[r.to_proto() for r in self.dual_rays],
solve_stats=self.solve_stats.to_proto(),
)
# Ensure that at most solver has solver specific output.
existing_solver_specific_output = None
def has_solver_specific_output(solver_name: str) -> None:
nonlocal existing_solver_specific_output
if existing_solver_specific_output is not None:
raise ValueError(
"found solver specific output for both"
f" {existing_solver_specific_output} and {solver_name}"
)
existing_solver_specific_output = solver_name
if self.gscip_specific_output is not None:
has_solver_specific_output("gscip")
proto.gscip_output.CopyFrom(self.gscip_specific_output)
if self.osqp_specific_output is not None:
has_solver_specific_output("osqp")
proto.osqp_output.CopyFrom(self.osqp_specific_output)
if self.pdlp_specific_output is not None:
has_solver_specific_output("pdlp")
proto.pdlp_output.CopyFrom(self.pdlp_specific_output)
return proto
def _get_problem_status(
result_proto: result_pb2.SolveResultProto,

View File

@@ -16,6 +16,8 @@ import datetime
import math
from absl.testing import absltest
from ortools.pdlp import solve_log_pb2
from ortools.gscip import gscip_pb2
from ortools.math_opt import result_pb2
from ortools.math_opt import solution_pb2
from ortools.math_opt import sparse_containers_pb2
@@ -23,9 +25,10 @@ from ortools.math_opt.python import model
from ortools.math_opt.python import result
from ortools.math_opt.python import solution
from ortools.math_opt.python.testing import compare_proto
from ortools.math_opt.solvers import osqp_pb2
class ParseTerminationReason(compare_proto.MathOptProtoAssertions, absltest.TestCase):
class TerminationTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
def test_termination_unspecified(self) -> None:
termination_proto = result_pb2.TerminationProto(
@@ -54,7 +57,19 @@ class ParseTerminationReason(compare_proto.MathOptProtoAssertions, absltest.Test
):
result.parse_termination(termination_proto)
def test_termination_ok(self) -> None:
def test_termination_ok_proto_round_trip(self) -> None:
termination = result.Termination(
reason=result.TerminationReason.NO_SOLUTION_FOUND,
limit=result.Limit.OTHER,
detail="detail",
problem_status=result.ProblemStatus(
primal_status=result.FeasibilityStatus.FEASIBLE,
dual_status=result.FeasibilityStatus.INFEASIBLE,
primal_or_dual_infeasible=False,
),
objective_bounds=result.ObjectiveBounds(primal_bound=10, dual_bound=20),
)
termination_proto = result_pb2.TerminationProto(
reason=result_pb2.TERMINATION_REASON_NO_SOLUTION_FOUND,
limit=result_pb2.LIMIT_OTHER,
@@ -68,22 +83,12 @@ class ParseTerminationReason(compare_proto.MathOptProtoAssertions, absltest.Test
primal_bound=10, dual_bound=20
),
)
termination = result.parse_termination(termination_proto)
self.assertEqual(termination.reason, result.TerminationReason.NO_SOLUTION_FOUND)
self.assertEqual(termination.limit, result.Limit.OTHER)
self.assertEqual(termination.detail, "detail")
self.assertEqual(
termination.problem_status,
result.ProblemStatus(
primal_status=result.FeasibilityStatus.FEASIBLE,
dual_status=result.FeasibilityStatus.INFEASIBLE,
primal_or_dual_infeasible=False,
),
)
self.assertEqual(
termination.objective_bounds,
result.ObjectiveBounds(primal_bound=10, dual_bound=20),
)
# Test proto-> Termination
self.assertEqual(result.parse_termination(termination_proto), termination)
# Test Termination -> proto
self.assert_protos_equiv(termination.to_proto(), termination_proto)
class ParseProblemStatus(compare_proto.MathOptProtoAssertions, absltest.TestCase):
@@ -617,83 +622,105 @@ def _make_undetermined_result_proto() -> result_pb2.SolveResultProto:
primal_bound=math.inf,
dual_bound=-math.inf,
),
),
solutions=[
solution_pb2.SolutionProto(
primal_solution=solution_pb2.PrimalSolutionProto(
objective_value=2.0,
variable_values=sparse_containers_pb2.SparseDoubleVectorProto(
ids=[0], values=[1.0]
),
feasibility_status=solution_pb2.SOLUTION_STATUS_UNDETERMINED,
)
)
],
)
)
proto.solve_stats.problem_status.primal_status = (
result_pb2.FEASIBILITY_STATUS_UNDETERMINED
)
proto.solve_stats.problem_status.dual_status = (
result_pb2.FEASIBILITY_STATUS_UNDETERMINED
)
proto.solve_stats.problem_status.primal_or_dual_infeasible = False
proto.solve_stats.best_primal_bound = math.inf
proto.solve_stats.best_dual_bound = -math.inf
proto.solve_stats.solve_time.FromTimedelta(datetime.timedelta(minutes=2))
return proto
def _make_undetermined_solve_result() -> result.SolveResult:
return result.SolveResult(
termination=result.Termination(
reason=result.TerminationReason.NO_SOLUTION_FOUND,
limit=result.Limit.TIME,
problem_status=result.ProblemStatus(
primal_status=result.FeasibilityStatus.UNDETERMINED,
dual_status=result.FeasibilityStatus.UNDETERMINED,
),
objective_bounds=result.ObjectiveBounds(
primal_bound=math.inf, dual_bound=-math.inf
),
),
solve_stats=result.SolveStats(solve_time=datetime.timedelta(minutes=2)),
)
class SolveResultTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
def test_solve_result_gscip_output(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable()
res = _make_undetermined_solve_result()
res.gscip_specific_output = gscip_pb2.GScipOutput(status_detail="gscip_detail")
proto = _make_undetermined_result_proto()
proto.gscip_output.status_detail = "gscip_detail"
res = result.parse_solve_result(proto, mod)
assert res.gscip_specific_output is not None
self.assertEqual("gscip_detail", res.gscip_specific_output.status_detail)
def test_solve_result_no_gscip_output(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable()
proto = _make_undetermined_result_proto()
res = result.parse_solve_result(proto, mod)
self.assertIsNone(res.gscip_specific_output)
# proto -> result
actual_res = result.parse_solve_result(proto, mod)
self.assertIsNotNone(actual_res.gscip_specific_output)
assert actual_res.gscip_specific_output is not None
self.assertEqual("gscip_detail", actual_res.gscip_specific_output.status_detail)
self.assertIsNone(actual_res.pdlp_specific_output)
self.assertIsNone(actual_res.osqp_specific_output)
# result -> proto
self.assert_protos_equiv(res.to_proto(), proto)
def test_solve_result_osqp_output(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable()
proto = _make_undetermined_result_proto()
proto.osqp_output.initialized_underlying_solver = False
res = result.parse_solve_result(proto, mod)
assert res.osqp_specific_output is not None
self.assertFalse(res.osqp_specific_output.initialized_underlying_solver)
res = _make_undetermined_solve_result()
res.osqp_specific_output = osqp_pb2.OsqpOutput(
initialized_underlying_solver=True
)
def test_solve_result_no_osqp_output(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable()
proto = _make_undetermined_result_proto()
res = result.parse_solve_result(proto, mod)
self.assertIsNone(res.osqp_specific_output)
proto.osqp_output.initialized_underlying_solver = True
# proto -> result
actual_res = result.parse_solve_result(proto, mod)
self.assertIsNotNone(actual_res.osqp_specific_output)
assert actual_res.osqp_specific_output is not None
self.assertTrue(actual_res.osqp_specific_output.initialized_underlying_solver)
self.assertIsNone(actual_res.pdlp_specific_output)
self.assertIsNone(actual_res.gscip_specific_output)
# result -> proto
self.assert_protos_equiv(res.to_proto(), proto)
def test_solve_result_pdlp_output(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable()
proto = _make_undetermined_result_proto()
proto.pdlp_output.convergence_information.corrected_dual_objective = 2.0
res = result.parse_solve_result(proto, mod)
assert res.pdlp_specific_output is not None
self.assertEqual(
res.pdlp_specific_output.convergence_information.corrected_dual_objective,
2.0,
res = _make_undetermined_solve_result()
res.pdlp_specific_output = result_pb2.SolveResultProto.PdlpOutput(
convergence_information=solve_log_pb2.ConvergenceInformation(
primal_objective=1.0
)
)
def test_solve_result_no_pdlp_output(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable()
proto = _make_undetermined_result_proto()
res = result.parse_solve_result(proto, mod)
self.assertIsNone(res.pdlp_specific_output)
proto.pdlp_output.convergence_information.primal_objective = 1.0
# proto -> result
actual_res = result.parse_solve_result(proto, mod)
self.assertIsNotNone(actual_res.pdlp_specific_output)
assert actual_res.pdlp_specific_output is not None
self.assertEqual(
actual_res.pdlp_specific_output.convergence_information.primal_objective,
1.0,
)
self.assertIsNone(actual_res.osqp_specific_output)
self.assertIsNone(actual_res.gscip_specific_output)
# result -> proto
self.assert_protos_equiv(res.to_proto(), proto)
def test_multiple_solver_specific_outputs_error(self) -> None:
res = _make_undetermined_solve_result()
res.gscip_specific_output = gscip_pb2.GScipOutput(status_detail="gscip_detail")
res.osqp_specific_output = osqp_pb2.OsqpOutput(
initialized_underlying_solver=False
)
with self.assertRaisesRegex(ValueError, "solver specific output"):
res.to_proto()
def test_solve_result_from_proto_missing_bounds_in_termination(
self,
@@ -1048,6 +1075,78 @@ class SolveResultTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
self.assertEqual(20, res.termination.objective_bounds.dual_bound)
self.assertIsNone(res.gscip_specific_output)
def test_to_proto_round_trip(self) -> None:
mod = model.Model(name="test_model")
x = mod.add_binary_variable(name="x")
c = mod.add_linear_constraint(lb=0.0, ub=1.0, name="c")
s = solution.Solution(
primal_solution=solution.PrimalSolution(
variable_values={x: 1.0},
objective_value=2.0,
feasibility_status=solution.SolutionStatus.FEASIBLE,
)
)
r = result.SolveResult(
termination=result.Termination(
reason=result.TerminationReason.FEASIBLE,
limit=result.Limit.TIME,
problem_status=result.ProblemStatus(
primal_status=result.FeasibilityStatus.FEASIBLE,
dual_status=result.FeasibilityStatus.UNDETERMINED,
),
),
solve_stats=result.SolveStats(
node_count=3, solve_time=datetime.timedelta(seconds=4)
),
solutions=[s],
primal_rays=[solution.PrimalRay(variable_values={x: 4.0})],
dual_rays=[solution.DualRay(reduced_costs={x: 5.0}, dual_values={c: 6.0})],
)
s_proto = solution_pb2.SolutionProto(
primal_solution=solution_pb2.PrimalSolutionProto(
objective_value=2.0,
feasibility_status=solution_pb2.SOLUTION_STATUS_FEASIBLE,
variable_values=sparse_containers_pb2.SparseDoubleVectorProto(
ids=[0], values=[1.0]
),
)
)
r_proto = result_pb2.SolveResultProto(
termination=result_pb2.TerminationProto(
reason=result_pb2.TERMINATION_REASON_FEASIBLE,
limit=result_pb2.LIMIT_TIME,
problem_status=result_pb2.ProblemStatusProto(
primal_status=result_pb2.FEASIBILITY_STATUS_FEASIBLE,
dual_status=result_pb2.FEASIBILITY_STATUS_UNDETERMINED,
),
),
solve_stats=result_pb2.SolveStatsProto(node_count=3),
solutions=[s_proto],
primal_rays=[
solution_pb2.PrimalRayProto(
variable_values=sparse_containers_pb2.SparseDoubleVectorProto(
ids=[0], values=[4.0]
)
)
],
dual_rays=[
solution_pb2.DualRayProto(
reduced_costs=sparse_containers_pb2.SparseDoubleVectorProto(
ids=[0], values=[5.0]
),
dual_values=sparse_containers_pb2.SparseDoubleVectorProto(
ids=[0], values=[6.0]
),
)
],
)
r_proto.solve_stats.solve_time.FromTimedelta(datetime.timedelta(seconds=4))
self.assert_protos_equiv(r.to_proto(), r_proto)
self.assertEqual(result.parse_solve_result(r_proto, mod), r)
if __name__ == "__main__":
absltest.main()

View File

@@ -164,6 +164,14 @@ class PrimalRay:
default_factory=dict
)
def to_proto(self) -> solution_pb2.PrimalRayProto:
"""Returns an equivalent proto to this PrimalRay."""
return solution_pb2.PrimalRayProto(
variable_values=sparse_containers.to_sparse_double_vector_proto(
self.variable_values
)
)
def parse_primal_ray(proto: solution_pb2.PrimalRayProto, mod: model.Model) -> PrimalRay:
"""Returns an equivalent PrimalRay from the input proto."""
@@ -278,6 +286,17 @@ class DualRay:
)
reduced_costs: Dict[model.Variable, float] = dataclasses.field(default_factory=dict)
def to_proto(self) -> solution_pb2.DualRayProto:
"""Returns an equivalent proto to this PrimalRay."""
return solution_pb2.DualRayProto(
dual_values=sparse_containers.to_sparse_double_vector_proto(
self.dual_values
),
reduced_costs=sparse_containers.to_sparse_double_vector_proto(
self.reduced_costs
),
)
def parse_dual_ray(proto: solution_pb2.DualRayProto, mod: model.Model) -> DualRay:
"""Returns an equivalent DualRay from the input proto."""

View File

@@ -76,17 +76,24 @@ class ParsePrimalSolutionTest(compare_proto.MathOptProtoAssertions, absltest.Tes
solution.parse_primal_solution(proto, mod)
class ParsePrimalRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
class PrimalRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
def test_parse(self) -> None:
def test_proto_round_trip(self) -> None:
mod = model.Model(name="test_model")
x = mod.add_binary_variable(name="x")
y = mod.add_binary_variable(name="y")
proto = solution_pb2.PrimalRayProto()
proto.variable_values.ids[:] = [0, 1]
proto.variable_values.values[:] = [1.0, 1.0]
actual = solution.parse_primal_ray(proto, mod)
self.assertDictEqual({x: 1.0, y: 1.0}, actual.variable_values)
ray = solution.PrimalRay(variable_values={x: 1.0, y: 1.0})
ray_proto = solution_pb2.PrimalRayProto()
ray_proto.variable_values.ids[:] = [0, 1]
ray_proto.variable_values.values[:] = [1.0, 1.0]
# Test proto -> model
parsed_ray = solution.parse_primal_ray(ray_proto, mod)
self.assertDictEqual({x: 1.0, y: 1.0}, parsed_ray.variable_values)
# Test model -> proto
exported_ray = ray.to_proto()
self.assert_protos_equiv(exported_ray, ray_proto)
class ParseDualSolutionTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
@@ -151,22 +158,33 @@ class ParseDualSolutionTest(compare_proto.MathOptProtoAssertions, absltest.TestC
solution.parse_dual_solution(proto, mod)
class ParseDualRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
class DualRayTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
def test_parse(self) -> None:
def test_proto_round_trip(self) -> None:
mod = model.Model(name="test_model")
x = mod.add_binary_variable(name="x")
y = mod.add_binary_variable(name="y")
c = mod.add_linear_constraint(lb=0.0, ub=1.0, name="c")
d = mod.add_linear_constraint(lb=0.0, ub=1.0, name="d")
proto = solution_pb2.DualRayProto()
proto.dual_values.ids[:] = [0, 1]
proto.dual_values.values[:] = [0.0, 1.0]
proto.reduced_costs.ids[:] = [0, 1]
proto.reduced_costs.values[:] = [10.0, 0.0]
actual = solution.parse_dual_ray(proto, mod)
self.assertDictEqual({x: 10.0, y: 0.0}, actual.reduced_costs)
self.assertDictEqual({c: 0.0, d: 1.0}, actual.dual_values)
dual_ray = solution.DualRay(
dual_values={c: 0.0, d: 1.0}, reduced_costs={x: 10.0, y: 0.0}
)
dual_ray_proto = solution_pb2.DualRayProto()
dual_ray_proto.dual_values.ids[:] = [0, 1]
dual_ray_proto.dual_values.values[:] = [0.0, 1.0]
dual_ray_proto.reduced_costs.ids[:] = [0, 1]
dual_ray_proto.reduced_costs.values[:] = [10.0, 0.0]
# Test proto -> dual ray
parsed_ray = solution.parse_dual_ray(dual_ray_proto, mod)
self.assertDictEqual(dual_ray.reduced_costs, parsed_ray.reduced_costs)
self.assertDictEqual(dual_ray.dual_values, parsed_ray.dual_values)
# Test dual ray -> proto
exported_proto = dual_ray.to_proto()
self.assert_protos_equiv(exported_proto, dual_ray_proto)
class BasisTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):

View File

@@ -21,6 +21,7 @@ from ortools.math_opt.core.python import solver
from ortools.math_opt.python import callback
from ortools.math_opt.python import compute_infeasible_subsystem_result
from ortools.math_opt.python import errors
from ortools.math_opt.python import init_arguments
from ortools.math_opt.python import message_callback
from ortools.math_opt.python import model
from ortools.math_opt.python import model_parameters
@@ -40,6 +41,7 @@ def solve(
msg_cb: Optional[message_callback.SolveMessageCallback] = None,
callback_reg: Optional[callback.CallbackRegistration] = None,
cb: Optional[SolveCallback] = None,
streamable_init_args: Optional[init_arguments.StreamableSolverInitArguments] = None,
) -> result.SolveResult:
"""Solves an optimization model.
@@ -56,6 +58,7 @@ def solve(
callback_reg: Configures when the callback will be invoked (if provided) and
what data will be collected to access in the callback.
cb: A callback that will be called periodically as the solver runs.
streamable_init_args: Configuration for initializing the underlying solver.
Returns:
A SolveResult containing the termination reason, solution(s) and stats.
@@ -68,6 +71,9 @@ def solve(
params = params or parameters.SolveParameters()
model_params = model_params or model_parameters.ModelSolveParameters()
callback_reg = callback_reg or callback.CallbackRegistration()
streamable_init_args = (
streamable_init_args or init_arguments.StreamableSolverInitArguments()
)
model_proto = opt_model.export_model()
proto_cb = None
if cb is not None:
@@ -79,7 +85,7 @@ def solve(
proto_result = solver.solve(
model_proto,
solver_type.value,
parameters_pb2.SolverInitializerProto(),
streamable_init_args.to_proto(),
params.to_proto(),
model_params.to_proto(),
msg_cb,
@@ -98,6 +104,7 @@ def compute_infeasible_subsystem(
*,
params: Optional[parameters.SolveParameters] = None,
msg_cb: Optional[message_callback.SolveMessageCallback] = None,
streamable_init_args: Optional[init_arguments.StreamableSolverInitArguments] = None,
) -> compute_infeasible_subsystem_result.ComputeInfeasibleSubsystemResult:
"""Computes an infeasible subsystem of the input model.
@@ -107,6 +114,7 @@ def compute_infeasible_subsystem(
August 2023, the only supported solver is Gurobi.
params: Configuration of the underlying solver.
msg_cb: A callback that gives back the underlying solver's logs by the line.
streamable_init_args: Configuration for initializing the underlying solver.
Returns:
An `ComputeInfeasibleSubsystemResult` where `feasibility` indicates if the
@@ -116,13 +124,16 @@ def compute_infeasible_subsystem(
RuntimeError: on invalid inputs or an internal solver error.
"""
params = params or parameters.SolveParameters()
streamable_init_args = (
streamable_init_args or init_arguments.StreamableSolverInitArguments()
)
model_proto = opt_model.export_model()
# Solve
try:
proto_result = solver.compute_infeasible_subsystem(
model_proto,
solver_type.value,
parameters_pb2.SolverInitializerProto(),
streamable_init_args.to_proto(),
params.to_proto(),
msg_cb,
None,
@@ -163,7 +174,18 @@ class IncrementalSolver:
When it is not possible to use `with`, the close() method can be called.
"""
def __init__(self, opt_model: model.Model, solver_type: parameters.SolverType):
def __init__(
self,
opt_model: model.Model,
solver_type: parameters.SolverType,
*,
streamable_init_args: Optional[
init_arguments.StreamableSolverInitArguments
] = None,
):
streamable_init_args = (
streamable_init_args or init_arguments.StreamableSolverInitArguments()
)
self._model = opt_model
self._solver_type = solver_type
self._update_tracker = self._model.add_update_tracker()
@@ -171,7 +193,7 @@ class IncrementalSolver:
self._proto_solver = solver.new(
solver_type.value,
self._model.export_model(),
parameters_pb2.SolverInitializerProto(),
streamable_init_args.to_proto(),
)
except StatusNotOk as e:
raise _status_not_ok_to_exception(e) from None

View File

@@ -19,8 +19,10 @@ machine.
"""
from absl.testing import absltest
from ortools.gurobi.isv.secret import gurobi_test_isv_key
from ortools.math_opt.python import callback
from ortools.math_opt.python import compute_infeasible_subsystem_result
from ortools.math_opt.python import init_arguments
from ortools.math_opt.python import model
from ortools.math_opt.python import parameters
from ortools.math_opt.python import result
@@ -29,6 +31,18 @@ from ortools.math_opt.python import solve
_Bounds = compute_infeasible_subsystem_result.ModelSubsetBounds
_bad_isv_key = init_arguments.GurobiISVKey(
name="cat", application_name="hat", expiration=10, key="bat"
)
def _init_args(
gurobi_key: init_arguments.GurobiISVKey,
) -> init_arguments.StreamableSolverInitArguments:
return init_arguments.StreamableSolverInitArguments(
gurobi=init_arguments.StreamableGurobiInitArguments(isv_key=gurobi_key)
)
class SolveTest(absltest.TestCase):
@@ -89,6 +103,95 @@ class SolveTest(absltest.TestCase):
)
self.assertEmpty(iis.infeasible_subsystem.variable_integrality)
def test_solve_valid_isv_success(self):
mod = model.Model()
x = mod.add_binary_variable()
mod.maximize(x)
res = solve.solve(
mod,
parameters.SolverType.GUROBI,
streamable_init_args=_init_args(
gurobi_test_isv_key.google_test_isv_key_placeholder()
),
)
self.assertEqual(
res.termination.reason,
result.TerminationReason.OPTIMAL,
msg=res.termination,
)
self.assertAlmostEqual(1.0, res.termination.objective_bounds.primal_bound)
def test_solve_wrong_isv_error(self):
mod = model.Model()
x = mod.add_binary_variable()
mod.maximize(x)
with self.assertRaisesRegex(
ValueError, "failed to create Gurobi primary environment with ISV key"
):
solve.solve(
mod,
parameters.SolverType.GUROBI,
streamable_init_args=_init_args(_bad_isv_key),
)
def test_incremental_solver_valid_isv_success(self):
mod = model.Model()
x = mod.add_binary_variable()
mod.maximize(x)
s = solve.IncrementalSolver(
mod,
parameters.SolverType.GUROBI,
streamable_init_args=_init_args(
gurobi_test_isv_key.google_test_isv_key_placeholder()
),
)
res = s.solve()
self.assertEqual(
res.termination.reason,
result.TerminationReason.OPTIMAL,
msg=res.termination,
)
self.assertAlmostEqual(1.0, res.termination.objective_bounds.primal_bound)
def test_incremental_solver_wrong_isv_error(self):
mod = model.Model()
x = mod.add_binary_variable()
mod.maximize(x)
with self.assertRaisesRegex(
ValueError, "failed to create Gurobi primary environment with ISV key"
):
solve.IncrementalSolver(
mod,
parameters.SolverType.GUROBI,
streamable_init_args=_init_args(_bad_isv_key),
)
def test_compute_infeasible_subsystem_valid_isv_success(self):
mod = model.Model()
x = mod.add_binary_variable()
mod.add_linear_constraint(x >= 3.0)
res = solve.compute_infeasible_subsystem(
mod,
parameters.SolverType.GUROBI,
streamable_init_args=_init_args(
gurobi_test_isv_key.google_test_isv_key_placeholder()
),
)
self.assertEqual(res.feasibility, result.FeasibilityStatus.INFEASIBLE)
def test_compute_infeasible_subsystem_wrong_isv_error(self):
mod = model.Model()
x = mod.add_binary_variable()
mod.add_linear_constraint(x >= 3.0)
with self.assertRaisesRegex(
ValueError, "failed to create Gurobi primary environment with ISV key"
):
solve.compute_infeasible_subsystem(
mod,
parameters.SolverType.GUROBI,
streamable_init_args=_init_args(_bad_isv_key),
)
if __name__ == "__main__":
absltest.main()

View File

@@ -11,8 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
load("@rules_cc//cc:defs.bzl", "cc_proto_library")
load("@rules_python//python:proto.bzl", "py_proto_library")
package(default_visibility = ["//ortools/math_opt:__subpackages__"])
@@ -312,9 +312,9 @@ cc_test(
cc_test(
name = "cp_sat_solver_test",
timeout = "eternal",
srcs = ["cp_sat_solver_test.cc"],
shard_count = 10,
timeout = "eternal",
deps = [
":cp_sat_solver",
"//ortools/base:gmock_main",

View File

@@ -66,6 +66,7 @@
#include "ortools/math_opt/solvers/message_callback_data.h"
#include "ortools/util/solve_interrupter.h"
#include "ortools/util/status_macros.h"
#include "simplex/SimplexConst.h"
#include "util/HighsInt.h"
namespace operations_research::math_opt {

View File

@@ -182,11 +182,9 @@ cc_library(
":linear_constraint_storage",
":model_storage_types",
":objective_storage",
":sparse_matrix",
":update_trackers",
":variable_storage",
"//ortools/base:intops",
"//ortools/base:map_util",
"//ortools/base:status_macros",
"//ortools/math_opt:model_cc_proto",
"//ortools/math_opt:model_update_cc_proto",

View File

@@ -21,15 +21,12 @@
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/base/map_util.h"
#include "ortools/base/status_macros.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/core/model_summary.h"
@@ -41,7 +38,6 @@
#include "ortools/math_opt/sparse_containers.pb.h"
#include "ortools/math_opt/storage/iterators.h"
#include "ortools/math_opt/storage/linear_constraint_storage.h"
#include "ortools/math_opt/storage/sparse_matrix.h"
#include "ortools/math_opt/storage/update_trackers.h"
#include "ortools/math_opt/storage/variable_storage.h"
#include "ortools/math_opt/validators/model_validator.h"
@@ -84,19 +80,21 @@ absl::StatusOr<std::unique_ptr<ModelStorage>> ModelStorage::FromModelProto(
model_proto.linear_constraint_matrix());
// Add quadratic constraints.
storage->quadratic_constraints_.AddConstraints(
storage->copyable_data_.quadratic_constraints.AddConstraints(
model_proto.quadratic_constraints());
// Add SOC constraints.
storage->soc_constraints_.AddConstraints(
storage->copyable_data_.soc_constraints.AddConstraints(
model_proto.second_order_cone_constraints());
// Add SOS constraints.
storage->sos1_constraints_.AddConstraints(model_proto.sos1_constraints());
storage->sos2_constraints_.AddConstraints(model_proto.sos2_constraints());
storage->copyable_data_.sos1_constraints.AddConstraints(
model_proto.sos1_constraints());
storage->copyable_data_.sos2_constraints.AddConstraints(
model_proto.sos2_constraints());
// Add indicator constraints.
storage->indicator_constraints_.AddConstraints(
storage->copyable_data_.indicator_constraints.AddConstraints(
model_proto.indicator_constraints());
return storage;
@@ -147,42 +145,22 @@ void ModelStorage::UpdateLinearConstraintCoefficients(
std::unique_ptr<ModelStorage> ModelStorage::Clone(
const std::optional<absl::string_view> new_name) const {
ModelProto model_proto = ExportModel();
// We leverage the private copy constructor that copies copyable_data_ but not
// update_trackers_ here.
std::unique_ptr<ModelStorage> clone =
absl::WrapUnique(new ModelStorage(*this));
if (new_name.has_value()) {
model_proto.set_name(std::string(*new_name));
clone->copyable_data_.name = *new_name;
}
absl::StatusOr<std::unique_ptr<ModelStorage>> clone =
ModelStorage::FromModelProto(model_proto);
// Unless there is a very serious bug, a model exported by ExportModel()
// should always be valid.
CHECK_OK(clone.status());
// Update the next ids so that the clone does not reused any deleted id from
// the original.
clone.value()->ensure_next_variable_id_at_least(next_variable_id());
clone.value()->ensure_next_auxiliary_objective_id_at_least(
next_auxiliary_objective_id());
clone.value()->ensure_next_linear_constraint_id_at_least(
next_linear_constraint_id());
clone.value()->ensure_next_constraint_id_at_least(
next_constraint_id<QuadraticConstraintId>());
clone.value()->ensure_next_constraint_id_at_least(
next_constraint_id<SecondOrderConeConstraintId>());
clone.value()->ensure_next_constraint_id_at_least(
next_constraint_id<Sos1ConstraintId>());
clone.value()->ensure_next_constraint_id_at_least(
next_constraint_id<Sos2ConstraintId>());
clone.value()->ensure_next_constraint_id_at_least(
next_constraint_id<IndicatorConstraintId>());
return std::move(clone).value();
return clone;
}
VariableId ModelStorage::AddVariable(const double lower_bound,
const double upper_bound,
const bool is_integer,
const absl::string_view name) {
return variables_.Add(lower_bound, upper_bound, is_integer, name);
return copyable_data_.variables.Add(lower_bound, upper_bound, is_integer,
name);
}
void ModelStorage::AddVariables(const VariablesProto& variables) {
@@ -200,39 +178,39 @@ void ModelStorage::AddVariables(const VariablesProto& variables) {
}
void ModelStorage::DeleteVariable(const VariableId id) {
CHECK(variables_.contains(id));
CHECK(copyable_data_.variables.contains(id));
const auto& trackers = update_trackers_.GetUpdatedTrackers();
// Reuse output of GetUpdatedTrackers() only once to ensure a consistent view,
// do not call UpdateAndGetLinearConstraintDiffs() etc.
objectives_.DeleteVariable(
copyable_data_.objectives.DeleteVariable(
id,
MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_objective>(trackers));
linear_constraints_.DeleteVariable(
copyable_data_.linear_constraints.DeleteVariable(
id,
MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_linear_constraints>(
trackers));
quadratic_constraints_.DeleteVariable(id);
soc_constraints_.DeleteVariable(id);
sos1_constraints_.DeleteVariable(id);
sos2_constraints_.DeleteVariable(id);
indicator_constraints_.DeleteVariable(id);
variables_.Delete(
copyable_data_.quadratic_constraints.DeleteVariable(id);
copyable_data_.soc_constraints.DeleteVariable(id);
copyable_data_.sos1_constraints.DeleteVariable(id);
copyable_data_.sos2_constraints.DeleteVariable(id);
copyable_data_.indicator_constraints.DeleteVariable(id);
copyable_data_.variables.Delete(
id,
MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_variables>(trackers));
}
std::vector<VariableId> ModelStorage::variables() const {
return variables_.Variables();
return copyable_data_.variables.Variables();
}
std::vector<VariableId> ModelStorage::SortedVariables() const {
return variables_.SortedVariables();
return copyable_data_.variables.SortedVariables();
}
LinearConstraintId ModelStorage::AddLinearConstraint(
const double lower_bound, const double upper_bound,
const absl::string_view name) {
return linear_constraints_.Add(lower_bound, upper_bound, name);
return copyable_data_.linear_constraints.Add(lower_bound, upper_bound, name);
}
void ModelStorage::AddLinearConstraints(
@@ -253,16 +231,17 @@ void ModelStorage::AddLinearConstraints(
}
void ModelStorage::DeleteLinearConstraint(const LinearConstraintId id) {
CHECK(linear_constraints_.contains(id));
linear_constraints_.Delete(id, UpdateAndGetLinearConstraintDiffs());
CHECK(copyable_data_.linear_constraints.contains(id));
copyable_data_.linear_constraints.Delete(id,
UpdateAndGetLinearConstraintDiffs());
}
std::vector<LinearConstraintId> ModelStorage::LinearConstraints() const {
return linear_constraints_.LinearConstraints();
return copyable_data_.linear_constraints.LinearConstraints();
}
std::vector<LinearConstraintId> ModelStorage::SortedLinearConstraints() const {
return linear_constraints_.SortedLinearConstraints();
return copyable_data_.linear_constraints.SortedLinearConstraints();
}
void ModelStorage::AddAuxiliaryObjectives(
@@ -285,23 +264,26 @@ void ModelStorage::AddAuxiliaryObjectives(
// tries to create a very long RepeatedField.
ModelProto ModelStorage::ExportModel(const bool remove_names) const {
ModelProto result;
result.set_name(name_);
*result.mutable_variables() = variables_.Proto();
result.set_name(copyable_data_.name);
*result.mutable_variables() = copyable_data_.variables.Proto();
{
auto [primary, auxiliary] = objectives_.Proto();
auto [primary, auxiliary] = copyable_data_.objectives.Proto();
*result.mutable_objective() = std::move(primary);
*result.mutable_auxiliary_objectives() = std::move(auxiliary);
}
{
auto [constraints, matrix] = linear_constraints_.Proto();
auto [constraints, matrix] = copyable_data_.linear_constraints.Proto();
*result.mutable_linear_constraints() = std::move(constraints);
*result.mutable_linear_constraint_matrix() = std::move(matrix);
}
*result.mutable_quadratic_constraints() = quadratic_constraints_.Proto();
*result.mutable_second_order_cone_constraints() = soc_constraints_.Proto();
*result.mutable_sos1_constraints() = sos1_constraints_.Proto();
*result.mutable_sos2_constraints() = sos2_constraints_.Proto();
*result.mutable_indicator_constraints() = indicator_constraints_.Proto();
*result.mutable_quadratic_constraints() =
copyable_data_.quadratic_constraints.Proto();
*result.mutable_second_order_cone_constraints() =
copyable_data_.soc_constraints.Proto();
*result.mutable_sos1_constraints() = copyable_data_.sos1_constraints.Proto();
*result.mutable_sos2_constraints() = copyable_data_.sos2_constraints.Proto();
*result.mutable_indicator_constraints() =
copyable_data_.indicator_constraints.Proto();
// Performance can be improved when remove_names is true by just not
// extracting the names above instead of clearing them below, but this will
// be more code, see discussion on cl/549469633 and prototype in cl/549369764.
@@ -317,15 +299,19 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate(
// We must detect the empty case to prevent unneeded copies and merging in
// ExportModelUpdate().
if (storage.variables_.diff_is_empty(dirty_variables) &&
storage.objectives_.diff_is_empty(dirty_objective) &&
storage.linear_constraints_.diff_is_empty(dirty_linear_constraints) &&
storage.quadratic_constraints_.diff_is_empty(
if (storage.copyable_data_.variables.diff_is_empty(dirty_variables) &&
storage.copyable_data_.objectives.diff_is_empty(dirty_objective) &&
storage.copyable_data_.linear_constraints.diff_is_empty(
dirty_linear_constraints) &&
storage.copyable_data_.quadratic_constraints.diff_is_empty(
dirty_quadratic_constraints) &&
storage.soc_constraints_.diff_is_empty(dirty_soc_constraints) &&
storage.sos1_constraints_.diff_is_empty(dirty_sos1_constraints) &&
storage.sos2_constraints_.diff_is_empty(dirty_sos2_constraints) &&
storage.indicator_constraints_.diff_is_empty(
storage.copyable_data_.soc_constraints.diff_is_empty(
dirty_soc_constraints) &&
storage.copyable_data_.sos1_constraints.diff_is_empty(
dirty_sos1_constraints) &&
storage.copyable_data_.sos2_constraints.diff_is_empty(
dirty_sos2_constraints) &&
storage.copyable_data_.indicator_constraints.diff_is_empty(
dirty_indicator_constraints)) {
return std::nullopt;
}
@@ -335,18 +321,19 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate(
// Variable/constraint deletions.
{
VariableStorage::UpdateResult variable_update =
storage.variables_.Update(dirty_variables);
storage.copyable_data_.variables.Update(dirty_variables);
*result.mutable_deleted_variable_ids() = std::move(variable_update.deleted);
*result.mutable_variable_updates() = std::move(variable_update.updates);
*result.mutable_new_variables() = std::move(variable_update.creates);
}
const std::vector<VariableId> new_variables =
storage.variables_.VariablesFrom(dirty_variables.checkpoint);
storage.copyable_data_.variables.VariablesFrom(
dirty_variables.checkpoint);
// Linear constraint updates
{
LinearConstraintStorage::UpdateResult lin_con_update =
storage.linear_constraints_.Update(
storage.copyable_data_.linear_constraints.Update(
dirty_linear_constraints, dirty_variables.deleted, new_variables);
*result.mutable_deleted_linear_constraint_ids() =
std::move(lin_con_update.deleted);
@@ -360,25 +347,27 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate(
// Quadratic constraint updates
*result.mutable_quadratic_constraint_updates() =
storage.quadratic_constraints_.Update(dirty_quadratic_constraints);
storage.copyable_data_.quadratic_constraints.Update(
dirty_quadratic_constraints);
// Second-order cone constraint updates
*result.mutable_second_order_cone_constraint_updates() =
storage.soc_constraints_.Update(dirty_soc_constraints);
storage.copyable_data_.soc_constraints.Update(dirty_soc_constraints);
// SOS constraint updates
*result.mutable_sos1_constraint_updates() =
storage.sos1_constraints_.Update(dirty_sos1_constraints);
storage.copyable_data_.sos1_constraints.Update(dirty_sos1_constraints);
*result.mutable_sos2_constraint_updates() =
storage.sos2_constraints_.Update(dirty_sos2_constraints);
storage.copyable_data_.sos2_constraints.Update(dirty_sos2_constraints);
// Indicator constraint updates
*result.mutable_indicator_constraint_updates() =
storage.indicator_constraints_.Update(dirty_indicator_constraints);
storage.copyable_data_.indicator_constraints.Update(
dirty_indicator_constraints);
// Update the objective
{
auto [primary, auxiliary] = storage.objectives_.Update(
auto [primary, auxiliary] = storage.copyable_data_.objectives.Update(
dirty_objective, dirty_variables.deleted, new_variables);
*result.mutable_objective_updates() = std::move(primary);
*result.mutable_auxiliary_objectives_updates() = std::move(auxiliary);
@@ -392,25 +381,29 @@ ModelStorage::UpdateTrackerData::ExportModelUpdate(
void ModelStorage::UpdateTrackerData::AdvanceCheckpoint(
const ModelStorage& storage) {
storage.variables_.AdvanceCheckpointInDiff(dirty_variables);
storage.objectives_.AdvanceCheckpointInDiff(dirty_variables.checkpoint,
dirty_objective);
storage.linear_constraints_.AdvanceCheckpointInDiff(
storage.copyable_data_.variables.AdvanceCheckpointInDiff(dirty_variables);
storage.copyable_data_.objectives.AdvanceCheckpointInDiff(
dirty_variables.checkpoint, dirty_objective);
storage.copyable_data_.linear_constraints.AdvanceCheckpointInDiff(
dirty_variables.checkpoint, dirty_linear_constraints);
storage.quadratic_constraints_.AdvanceCheckpointInDiff(
storage.copyable_data_.quadratic_constraints.AdvanceCheckpointInDiff(
dirty_quadratic_constraints);
storage.soc_constraints_.AdvanceCheckpointInDiff(dirty_soc_constraints);
storage.sos1_constraints_.AdvanceCheckpointInDiff(dirty_sos1_constraints);
storage.sos2_constraints_.AdvanceCheckpointInDiff(dirty_sos2_constraints);
storage.indicator_constraints_.AdvanceCheckpointInDiff(
storage.copyable_data_.soc_constraints.AdvanceCheckpointInDiff(
dirty_soc_constraints);
storage.copyable_data_.sos1_constraints.AdvanceCheckpointInDiff(
dirty_sos1_constraints);
storage.copyable_data_.sos2_constraints.AdvanceCheckpointInDiff(
dirty_sos2_constraints);
storage.copyable_data_.indicator_constraints.AdvanceCheckpointInDiff(
dirty_indicator_constraints);
}
UpdateTrackerId ModelStorage::NewUpdateTracker() {
return update_trackers_.NewUpdateTracker(
variables_, objectives_, linear_constraints_, quadratic_constraints_,
soc_constraints_, sos1_constraints_, sos2_constraints_,
indicator_constraints_);
copyable_data_.variables, copyable_data_.objectives,
copyable_data_.linear_constraints, copyable_data_.quadratic_constraints,
copyable_data_.soc_constraints, copyable_data_.sos1_constraints,
copyable_data_.sos2_constraints, copyable_data_.indicator_constraints);
}
void ModelStorage::DeleteUpdateTracker(const UpdateTrackerId update_tracker) {
@@ -438,50 +431,50 @@ absl::Status ModelStorage::ApplyUpdateProto(
RETURN_IF_ERROR(summary.variables.Insert(id.value(), variable_name(id)))
<< "invalid variable id in model";
}
RETURN_IF_ERROR(
summary.variables.SetNextFreeId(variables_.next_id().value()));
RETURN_IF_ERROR(summary.variables.SetNextFreeId(
copyable_data_.variables.next_id().value()));
for (const AuxiliaryObjectiveId id : SortedAuxiliaryObjectives()) {
RETURN_IF_ERROR(
summary.auxiliary_objectives.Insert(id.value(), objective_name(id)))
<< "invalid auxiliary objective id in model";
}
RETURN_IF_ERROR(summary.auxiliary_objectives.SetNextFreeId(
objectives_.next_id().value()));
copyable_data_.objectives.next_id().value()));
for (const LinearConstraintId id : SortedLinearConstraints()) {
RETURN_IF_ERROR(summary.linear_constraints.Insert(
id.value(), linear_constraint_name(id)))
<< "invalid linear constraint id in model";
}
RETURN_IF_ERROR(summary.linear_constraints.SetNextFreeId(
linear_constraints_.next_id().value()));
copyable_data_.linear_constraints.next_id().value()));
for (const auto id : SortedConstraints<QuadraticConstraintId>()) {
RETURN_IF_ERROR(summary.quadratic_constraints.Insert(
id.value(), quadratic_constraints_.data(id).name))
id.value(), copyable_data_.quadratic_constraints.data(id).name))
<< "invalid quadratic constraint id in model";
}
RETURN_IF_ERROR(summary.quadratic_constraints.SetNextFreeId(
quadratic_constraints_.next_id().value()));
copyable_data_.quadratic_constraints.next_id().value()));
for (const auto id : SortedConstraints<SecondOrderConeConstraintId>()) {
RETURN_IF_ERROR(summary.second_order_cone_constraints.Insert(
id.value(), soc_constraints_.data(id).name))
id.value(), copyable_data_.soc_constraints.data(id).name))
<< "invalid second-order cone constraint id in model";
}
RETURN_IF_ERROR(summary.second_order_cone_constraints.SetNextFreeId(
soc_constraints_.next_id().value()));
copyable_data_.soc_constraints.next_id().value()));
for (const Sos1ConstraintId id : SortedConstraints<Sos1ConstraintId>()) {
RETURN_IF_ERROR(summary.sos1_constraints.Insert(
id.value(), constraint_data(id).name()))
<< "invalid SOS1 constraint id in model";
}
RETURN_IF_ERROR(summary.sos1_constraints.SetNextFreeId(
sos1_constraints_.next_id().value()));
copyable_data_.sos1_constraints.next_id().value()));
for (const Sos2ConstraintId id : SortedConstraints<Sos2ConstraintId>()) {
RETURN_IF_ERROR(summary.sos2_constraints.Insert(
id.value(), constraint_data(id).name()))
<< "invalid SOS2 constraint id in model";
}
RETURN_IF_ERROR(summary.sos2_constraints.SetNextFreeId(
sos2_constraints_.next_id().value()));
copyable_data_.sos2_constraints.next_id().value()));
for (const IndicatorConstraintId id :
SortedConstraints<IndicatorConstraintId>()) {
@@ -489,7 +482,7 @@ absl::Status ModelStorage::ApplyUpdateProto(
id.value(), constraint_data(id).name));
}
RETURN_IF_ERROR(summary.indicator_constraints.SetNextFreeId(
indicator_constraints_.next_id().value()));
copyable_data_.indicator_constraints.next_id().value()));
RETURN_IF_ERROR(ValidateModelUpdate(update_proto, summary))
<< "update not valid";
@@ -556,15 +549,15 @@ absl::Status ModelStorage::ApplyUpdateProto(
AddAuxiliaryObjectives(
update_proto.auxiliary_objectives_updates().new_objectives());
AddLinearConstraints(update_proto.new_linear_constraints());
quadratic_constraints_.AddConstraints(
copyable_data_.quadratic_constraints.AddConstraints(
update_proto.quadratic_constraint_updates().new_constraints());
soc_constraints_.AddConstraints(
copyable_data_.soc_constraints.AddConstraints(
update_proto.second_order_cone_constraint_updates().new_constraints());
sos1_constraints_.AddConstraints(
copyable_data_.sos1_constraints.AddConstraints(
update_proto.sos1_constraint_updates().new_constraints());
sos2_constraints_.AddConstraints(
copyable_data_.sos2_constraints.AddConstraints(
update_proto.sos2_constraint_updates().new_constraints());
indicator_constraints_.AddConstraints(
copyable_data_.indicator_constraints.AddConstraints(
update_proto.indicator_constraint_updates().new_constraints());
// Update the primary objective.

View File

@@ -171,7 +171,6 @@ class ModelStorage {
inline explicit ModelStorage(absl::string_view model_name = "",
absl::string_view primary_objective_name = "");
ModelStorage(const ModelStorage&) = delete;
ModelStorage& operator=(const ModelStorage&) = delete;
// Returns a clone of the model, optionally changing model's name.
@@ -183,7 +182,7 @@ class ModelStorage {
std::unique_ptr<ModelStorage> Clone(
std::optional<absl::string_view> new_name = std::nullopt) const;
inline const std::string& name() const { return name_; }
inline const std::string& name() const { return copyable_data_.name; }
//////////////////////////////////////////////////////////////////////////////
// Variables
@@ -692,6 +691,30 @@ class ModelStorage {
dirty_indicator_constraints;
};
// All data that is copied (by the C++ default copy constructor) when using
// Clone().
struct CopyableData {
CopyableData(const absl::string_view model_name,
const absl::string_view primary_objective_name)
: name(model_name), objectives(/*name=*/primary_objective_name) {}
std::string name;
VariableStorage variables;
ObjectiveStorage objectives;
LinearConstraintStorage linear_constraints;
AtomicConstraintStorage<QuadraticConstraintData> quadratic_constraints;
AtomicConstraintStorage<SecondOrderConeConstraintData> soc_constraints;
AtomicConstraintStorage<Sos1ConstraintData> sos1_constraints;
AtomicConstraintStorage<Sos2ConstraintData> sos2_constraints;
AtomicConstraintStorage<IndicatorConstraintData> indicator_constraints;
};
// Private copy constructor that copies only copyable_data_, not
// update_trackers_. It is used internally by Clone().
ModelStorage(const ModelStorage& other)
: copyable_data_(other.copyable_data_) {}
auto UpdateAndGetVariableDiffs() {
return MakeUpdateDataFieldRange<&UpdateTrackerData::dirty_variables>(
update_trackers_.GetUpdatedTrackers());
@@ -745,18 +768,7 @@ class ModelStorage {
template <typename ConstraintData>
const AtomicConstraintStorage<ConstraintData>& constraint_storage() const;
std::string name_;
VariableStorage variables_;
ObjectiveStorage objectives_;
LinearConstraintStorage linear_constraints_;
AtomicConstraintStorage<QuadraticConstraintData> quadratic_constraints_;
AtomicConstraintStorage<SecondOrderConeConstraintData> soc_constraints_;
AtomicConstraintStorage<Sos1ConstraintData> sos1_constraints_;
AtomicConstraintStorage<Sos2ConstraintData> sos2_constraints_;
AtomicConstraintStorage<IndicatorConstraintData> indicator_constraints_;
CopyableData copyable_data_;
UpdateTrackers<UpdateTrackerData> update_trackers_;
};
@@ -768,7 +780,8 @@ class ModelStorage {
ModelStorage::ModelStorage(const absl::string_view model_name,
const absl::string_view primary_objective_name)
: name_(model_name), objectives_(primary_objective_name) {}
: copyable_data_(/*model_name=*/model_name,
/*primary_objective_name=*/primary_objective_name) {}
////////////////////////////////////////////////////////////////////////////////
// Variables
@@ -780,34 +793,37 @@ VariableId ModelStorage::AddVariable(absl::string_view name) {
}
double ModelStorage::variable_lower_bound(const VariableId id) const {
return variables_.lower_bound(id);
return copyable_data_.variables.lower_bound(id);
}
double ModelStorage::variable_upper_bound(const VariableId id) const {
return variables_.upper_bound(id);
return copyable_data_.variables.upper_bound(id);
}
bool ModelStorage::is_variable_integer(VariableId id) const {
return variables_.is_integer(id);
return copyable_data_.variables.is_integer(id);
}
const std::string& ModelStorage::variable_name(const VariableId id) const {
return variables_.name(id);
return copyable_data_.variables.name(id);
}
void ModelStorage::set_variable_lower_bound(const VariableId id,
const double lower_bound) {
variables_.set_lower_bound(id, lower_bound, UpdateAndGetVariableDiffs());
copyable_data_.variables.set_lower_bound(id, lower_bound,
UpdateAndGetVariableDiffs());
}
void ModelStorage::set_variable_upper_bound(const VariableId id,
const double upper_bound) {
variables_.set_upper_bound(id, upper_bound, UpdateAndGetVariableDiffs());
copyable_data_.variables.set_upper_bound(id, upper_bound,
UpdateAndGetVariableDiffs());
}
void ModelStorage::set_variable_is_integer(const VariableId id,
const bool is_integer) {
variables_.set_integer(id, is_integer, UpdateAndGetVariableDiffs());
copyable_data_.variables.set_integer(id, is_integer,
UpdateAndGetVariableDiffs());
}
void ModelStorage::set_variable_as_integer(VariableId id) {
@@ -818,18 +834,20 @@ void ModelStorage::set_variable_as_continuous(VariableId id) {
set_variable_is_integer(id, false);
}
int ModelStorage::num_variables() const { return variables_.size(); }
int ModelStorage::num_variables() const {
return static_cast<int>(copyable_data_.variables.size());
}
VariableId ModelStorage::next_variable_id() const {
return variables_.next_id();
return copyable_data_.variables.next_id();
}
void ModelStorage::ensure_next_variable_id_at_least(const VariableId id) {
variables_.ensure_next_id_at_least(id);
copyable_data_.variables.ensure_next_id_at_least(id);
}
bool ModelStorage::has_variable(const VariableId id) const {
return variables_.contains(id);
return copyable_data_.variables.contains(id);
}
////////////////////////////////////////////////////////////////////////////////
@@ -843,46 +861,46 @@ LinearConstraintId ModelStorage::AddLinearConstraint(absl::string_view name) {
double ModelStorage::linear_constraint_lower_bound(
const LinearConstraintId id) const {
return linear_constraints_.lower_bound(id);
return copyable_data_.linear_constraints.lower_bound(id);
}
double ModelStorage::linear_constraint_upper_bound(
const LinearConstraintId id) const {
return linear_constraints_.upper_bound(id);
return copyable_data_.linear_constraints.upper_bound(id);
}
const std::string& ModelStorage::linear_constraint_name(
const LinearConstraintId id) const {
return linear_constraints_.name(id);
return copyable_data_.linear_constraints.name(id);
}
void ModelStorage::set_linear_constraint_lower_bound(
const LinearConstraintId id, const double lower_bound) {
linear_constraints_.set_lower_bound(id, lower_bound,
UpdateAndGetLinearConstraintDiffs());
copyable_data_.linear_constraints.set_lower_bound(
id, lower_bound, UpdateAndGetLinearConstraintDiffs());
}
void ModelStorage::set_linear_constraint_upper_bound(
const LinearConstraintId id, const double upper_bound) {
linear_constraints_.set_upper_bound(id, upper_bound,
UpdateAndGetLinearConstraintDiffs());
copyable_data_.linear_constraints.set_upper_bound(
id, upper_bound, UpdateAndGetLinearConstraintDiffs());
}
int ModelStorage::num_linear_constraints() const {
return linear_constraints_.size();
return static_cast<int>(copyable_data_.linear_constraints.size());
}
LinearConstraintId ModelStorage::next_linear_constraint_id() const {
return linear_constraints_.next_id();
return copyable_data_.linear_constraints.next_id();
}
void ModelStorage::ensure_next_linear_constraint_id_at_least(
LinearConstraintId id) {
linear_constraints_.ensure_next_id_at_least(id);
copyable_data_.linear_constraints.ensure_next_id_at_least(id);
}
bool ModelStorage::has_linear_constraint(const LinearConstraintId id) const {
return linear_constraints_.contains(id);
return copyable_data_.linear_constraints.contains(id);
}
////////////////////////////////////////////////////////////////////////////////
@@ -891,34 +909,35 @@ bool ModelStorage::has_linear_constraint(const LinearConstraintId id) const {
double ModelStorage::linear_constraint_coefficient(
LinearConstraintId constraint, VariableId variable) const {
return linear_constraints_.matrix().get(constraint, variable);
return copyable_data_.linear_constraints.matrix().get(constraint, variable);
}
bool ModelStorage::is_linear_constraint_coefficient_nonzero(
LinearConstraintId constraint, VariableId variable) const {
return linear_constraints_.matrix().contains(constraint, variable);
return copyable_data_.linear_constraints.matrix().contains(constraint,
variable);
}
void ModelStorage::set_linear_constraint_coefficient(
const LinearConstraintId constraint, const VariableId variable,
const double value) {
linear_constraints_.set_term(constraint, variable, value,
UpdateAndGetLinearConstraintDiffs());
copyable_data_.linear_constraints.set_term(
constraint, variable, value, UpdateAndGetLinearConstraintDiffs());
}
std::vector<std::tuple<LinearConstraintId, VariableId, double>>
std::vector<std::tuple<LinearConstraintId, VariableId, double> >
ModelStorage::linear_constraint_matrix() const {
return linear_constraints_.matrix().Terms();
return copyable_data_.linear_constraints.matrix().Terms();
}
std::vector<VariableId> ModelStorage::variables_in_linear_constraint(
LinearConstraintId constraint) const {
return linear_constraints_.matrix().row(constraint);
return copyable_data_.linear_constraints.matrix().row(constraint);
}
std::vector<LinearConstraintId> ModelStorage::linear_constraints_with_variable(
VariableId variable) const {
return linear_constraints_.matrix().column(variable);
return copyable_data_.linear_constraints.matrix().column(variable);
}
////////////////////////////////////////////////////////////////////////////////
@@ -926,47 +945,49 @@ std::vector<LinearConstraintId> ModelStorage::linear_constraints_with_variable(
////////////////////////////////////////////////////////////////////////////////
bool ModelStorage::is_maximize(const ObjectiveId id) const {
return objectives_.maximize(id);
return copyable_data_.objectives.maximize(id);
}
int64_t ModelStorage::objective_priority(const ObjectiveId id) const {
return objectives_.priority(id);
return copyable_data_.objectives.priority(id);
}
double ModelStorage::objective_offset(const ObjectiveId id) const {
return objectives_.offset(id);
return copyable_data_.objectives.offset(id);
}
double ModelStorage::linear_objective_coefficient(
const ObjectiveId id, const VariableId variable) const {
return objectives_.linear_term(id, variable);
return copyable_data_.objectives.linear_term(id, variable);
}
double ModelStorage::quadratic_objective_coefficient(
const ObjectiveId id, const VariableId first_variable,
const VariableId second_variable) const {
return objectives_.quadratic_term(id, first_variable, second_variable);
return copyable_data_.objectives.quadratic_term(id, first_variable,
second_variable);
}
bool ModelStorage::is_linear_objective_coefficient_nonzero(
const ObjectiveId id, const VariableId variable) const {
return objectives_.linear_terms(id).contains(variable);
return copyable_data_.objectives.linear_terms(id).contains(variable);
}
bool ModelStorage::is_quadratic_objective_coefficient_nonzero(
const ObjectiveId id, const VariableId first_variable,
const VariableId second_variable) const {
return objectives_.quadratic_terms(id).get(first_variable, second_variable) !=
0.0;
return copyable_data_.objectives.quadratic_terms(id).get(
first_variable, second_variable) != 0.0;
}
const std::string& ModelStorage::objective_name(const ObjectiveId id) const {
return objectives_.name(id);
return copyable_data_.objectives.name(id);
}
void ModelStorage::set_is_maximize(const ObjectiveId id,
const bool is_maximize) {
objectives_.set_maximize(id, is_maximize, UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.set_maximize(id, is_maximize,
UpdateAndGetObjectiveDiffs());
}
void ModelStorage::set_maximize(const ObjectiveId id) {
@@ -979,49 +1000,50 @@ void ModelStorage::set_minimize(const ObjectiveId id) {
void ModelStorage::set_objective_priority(const ObjectiveId id,
const int64_t value) {
objectives_.set_priority(id, value, UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.set_priority(id, value,
UpdateAndGetObjectiveDiffs());
}
void ModelStorage::set_objective_offset(const ObjectiveId id,
const double value) {
objectives_.set_offset(id, value, UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.set_offset(id, value, UpdateAndGetObjectiveDiffs());
}
void ModelStorage::set_linear_objective_coefficient(const ObjectiveId id,
const VariableId variable,
const double value) {
objectives_.set_linear_term(id, variable, value,
UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.set_linear_term(id, variable, value,
UpdateAndGetObjectiveDiffs());
}
void ModelStorage::set_quadratic_objective_coefficient(
const ObjectiveId id, const VariableId first_variable,
const VariableId second_variable, const double value) {
objectives_.set_quadratic_term(id, first_variable, second_variable, value,
UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.set_quadratic_term(
id, first_variable, second_variable, value, UpdateAndGetObjectiveDiffs());
}
void ModelStorage::clear_objective(const ObjectiveId id) {
objectives_.Clear(id, UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.Clear(id, UpdateAndGetObjectiveDiffs());
}
const absl::flat_hash_map<VariableId, double>& ModelStorage::linear_objective(
const ObjectiveId id) const {
return objectives_.linear_terms(id);
return copyable_data_.objectives.linear_terms(id);
}
int64_t ModelStorage::num_linear_objective_terms(const ObjectiveId id) const {
return objectives_.linear_terms(id).size();
return copyable_data_.objectives.linear_terms(id).size();
}
int64_t ModelStorage::num_quadratic_objective_terms(
const ObjectiveId id) const {
return objectives_.quadratic_terms(id).nonzeros();
return copyable_data_.objectives.quadratic_terms(id).nonzeros();
}
std::vector<std::tuple<VariableId, VariableId, double>>
std::vector<std::tuple<VariableId, VariableId, double> >
ModelStorage::quadratic_objective_terms(const ObjectiveId id) const {
return objectives_.quadratic_terms(id).Terms();
return copyable_data_.objectives.quadratic_terms(id).Terms();
}
////////////////////////////////////////////////////////////////////////////////
@@ -1030,38 +1052,38 @@ ModelStorage::quadratic_objective_terms(const ObjectiveId id) const {
AuxiliaryObjectiveId ModelStorage::AddAuxiliaryObjective(
const int64_t priority, const absl::string_view name) {
return objectives_.AddAuxiliaryObjective(priority, name);
return copyable_data_.objectives.AddAuxiliaryObjective(priority, name);
}
void ModelStorage::DeleteAuxiliaryObjective(const AuxiliaryObjectiveId id) {
objectives_.Delete(id, UpdateAndGetObjectiveDiffs());
copyable_data_.objectives.Delete(id, UpdateAndGetObjectiveDiffs());
}
int ModelStorage::num_auxiliary_objectives() const {
return static_cast<int>(objectives_.num_auxiliary_objectives());
return static_cast<int>(copyable_data_.objectives.num_auxiliary_objectives());
}
AuxiliaryObjectiveId ModelStorage::next_auxiliary_objective_id() const {
return objectives_.next_id();
return copyable_data_.objectives.next_id();
}
void ModelStorage::ensure_next_auxiliary_objective_id_at_least(
const AuxiliaryObjectiveId id) {
objectives_.ensure_next_id_at_least(id);
copyable_data_.objectives.ensure_next_id_at_least(id);
}
bool ModelStorage::has_auxiliary_objective(
const AuxiliaryObjectiveId id) const {
return objectives_.contains(id);
return copyable_data_.objectives.contains(id);
}
std::vector<AuxiliaryObjectiveId> ModelStorage::AuxiliaryObjectives() const {
return objectives_.AuxiliaryObjectives();
return copyable_data_.objectives.AuxiliaryObjectives();
}
std::vector<AuxiliaryObjectiveId> ModelStorage::SortedAuxiliaryObjectives()
const {
return objectives_.SortedAuxiliaryObjectives();
return copyable_data_.objectives.SortedAuxiliaryObjectives();
}
////////////////////////////////////////////////////////////////////////////////
@@ -1162,13 +1184,13 @@ std::vector<VariableId> ModelStorage::VariablesInConstraint(
template <>
inline AtomicConstraintStorage<QuadraticConstraintData>&
ModelStorage::constraint_storage() {
return quadratic_constraints_;
return copyable_data_.quadratic_constraints;
}
template <>
inline const AtomicConstraintStorage<QuadraticConstraintData>&
ModelStorage::constraint_storage() const {
return quadratic_constraints_;
return copyable_data_.quadratic_constraints;
}
template <>
@@ -1184,13 +1206,13 @@ constexpr typename AtomicConstraintStorage<QuadraticConstraintData>::Diff
template <>
inline AtomicConstraintStorage<SecondOrderConeConstraintData>&
ModelStorage::constraint_storage() {
return soc_constraints_;
return copyable_data_.soc_constraints;
}
template <>
inline const AtomicConstraintStorage<SecondOrderConeConstraintData>&
ModelStorage::constraint_storage() const {
return soc_constraints_;
return copyable_data_.soc_constraints;
}
template <>
@@ -1206,13 +1228,13 @@ constexpr typename AtomicConstraintStorage<SecondOrderConeConstraintData>::Diff
template <>
inline AtomicConstraintStorage<Sos1ConstraintData>&
ModelStorage::constraint_storage() {
return sos1_constraints_;
return copyable_data_.sos1_constraints;
}
template <>
inline const AtomicConstraintStorage<Sos1ConstraintData>&
ModelStorage::constraint_storage() const {
return sos1_constraints_;
return copyable_data_.sos1_constraints;
}
template <>
@@ -1228,13 +1250,13 @@ constexpr typename AtomicConstraintStorage<Sos1ConstraintData>::Diff
template <>
inline AtomicConstraintStorage<Sos2ConstraintData>&
ModelStorage::constraint_storage() {
return sos2_constraints_;
return copyable_data_.sos2_constraints;
}
template <>
inline const AtomicConstraintStorage<Sos2ConstraintData>&
ModelStorage::constraint_storage() const {
return sos2_constraints_;
return copyable_data_.sos2_constraints;
}
template <>
@@ -1250,13 +1272,13 @@ constexpr typename AtomicConstraintStorage<Sos2ConstraintData>::Diff
template <>
inline AtomicConstraintStorage<IndicatorConstraintData>&
ModelStorage::constraint_storage() {
return indicator_constraints_;
return copyable_data_.indicator_constraints;
}
template <>
inline const AtomicConstraintStorage<IndicatorConstraintData>&
ModelStorage::constraint_storage() const {
return indicator_constraints_;
return copyable_data_.indicator_constraints;
}
template <>