math_opt: export from google3

This commit is contained in:
Corentin Le Molgat
2024-04-19 13:40:42 +02:00
parent faced59676
commit e97e98d83d
32 changed files with 139 additions and 95 deletions

View File

@@ -27,7 +27,7 @@ namespace operations_research {
// TODO(b/311704821): this function should not delegate to MPSolver, also true
// for the functions below.
MPSolutionResponse SolveMPModel(LazyMutableCopy<MPModelRequest> request,
SolveInterrupter* interrupter) {
const SolveInterrupter* interrupter) {
MPSolutionResponse response;
if (interrupter != nullptr) {
std::atomic<bool> atomic_bool = false;

View File

@@ -43,7 +43,7 @@ namespace operations_research {
* MPSOLVER_INCOMPATIBLE_OPTIONS error.
*/
MPSolutionResponse SolveMPModel(LazyMutableCopy<MPModelRequest> request,
SolveInterrupter* interrupter = nullptr);
const SolveInterrupter* interrupter = nullptr);
bool SolverTypeSupportsInterruption(MPModelRequest::SolverType solver);

View File

@@ -99,7 +99,7 @@ class Solver {
// An optional interrupter that the solver can use to interrupt the solve
// early.
SolveInterrupter* interrupter = nullptr;
const SolveInterrupter* interrupter = nullptr;
};
// Arguments used when calling ComputeInfeasibleSubsystem().
@@ -115,7 +115,7 @@ class Solver {
// An optional interrupter that the solver can use to interrupt the solve
// early.
SolveInterrupter* interrupter = nullptr;
const SolveInterrupter* interrupter = nullptr;
};
// A shortcut for calling Solver::New() and then Solver::Solve().

View File

@@ -122,7 +122,7 @@ class SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) = 0;
const SolveInterrupter* interrupter) = 0;
// Updates the model to solve and returns true, or returns false if this
// update is not supported.
@@ -151,7 +151,7 @@ class SolverInterface {
virtual absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) = 0;
const SolveInterrupter* interrupter) = 0;
};
class AllSolversRegistry {

View File

@@ -59,7 +59,7 @@ struct ComputeInfeasibleSubsystemArguments {
// ComputeInfeasibleSubsystem(model, SolverType::kGurobi,
// { .interrupter = interrupter.get() });
//
SolveInterrupter* interrupter = nullptr;
const SolveInterrupter* interrupter = nullptr;
};
} // namespace operations_research::math_opt

View File

@@ -82,7 +82,7 @@ struct SolveArguments {
// Solve(model, SolverType::kGlop,
// { .interrupter = interrupter.get() });
//
SolveInterrupter* interrupter = nullptr;
const SolveInterrupter* interrupter = nullptr;
// Returns a failure if the referenced variables and constraints don't belong
// to the input expected_storage (which must not be nullptr). Also returns a

View File

@@ -28,6 +28,9 @@ SolverResourcesProto SolverResources::Proto() const {
if (cpu.has_value()) {
ret.set_cpu(cpu.value());
}
if (ram.has_value()) {
ret.set_ram(ram.value());
}
return ret;
}
@@ -37,6 +40,9 @@ absl::StatusOr<SolverResources> SolverResources::FromProto(
if (proto.has_cpu()) {
ret.cpu = proto.cpu();
}
if (proto.has_ram()) {
ret.ram = proto.ram();
}
return ret;
}

View File

@@ -58,6 +58,10 @@ struct SolverResources {
// should also be left unset.
std::optional<double> cpu;
// The limit of RAM for the solve in bytes. Must be finite and >=1.0 (even
// though it should in practice be much larger).
std::optional<double> ram;
SolverResourcesProto Proto() const;
static absl::StatusOr<SolverResources> FromProto(
const SolverResourcesProto& proto);

View File

@@ -144,9 +144,11 @@ from ortools.math_opt.python.solution import Basis
from ortools.math_opt.python.solution import BasisStatus
from ortools.math_opt.python.solution import DualRay
from ortools.math_opt.python.solution import DualSolution
from ortools.math_opt.python.solution import optional_solution_status_to_proto
from ortools.math_opt.python.solution import parse_basis
from ortools.math_opt.python.solution import parse_dual_ray
from ortools.math_opt.python.solution import parse_dual_solution
from ortools.math_opt.python.solution import parse_optional_solution_status
from ortools.math_opt.python.solution import parse_primal_ray
from ortools.math_opt.python.solution import parse_primal_solution
from ortools.math_opt.python.solution import parse_solution

View File

@@ -100,9 +100,6 @@ class ModelParametersTest(compare_proto.MathOptProtoAssertions, absltest.TestCas
expected.initial_basis.variable_status.values.append(
solution_pb2.BASIS_STATUS_AT_UPPER_BOUND
)
expected.initial_basis.basic_dual_feasibility = (
solution_pb2.SOLUTION_STATUS_UNDETERMINED
)
self.assert_protos_equiv(expected, actual)

View File

@@ -58,6 +58,24 @@ class SolutionStatus(enum.Enum):
INFEASIBLE = solution_pb2.SOLUTION_STATUS_INFEASIBLE
def parse_optional_solution_status(
proto: solution_pb2.SolutionStatusProto,
) -> Optional[SolutionStatus]:
"""Converts a proto SolutionStatus to an optional Python SolutionStatus."""
return (
None
if proto == solution_pb2.SOLUTION_STATUS_UNSPECIFIED
else SolutionStatus(proto)
)
def optional_solution_status_to_proto(
status: Optional[SolutionStatus],
) -> solution_pb2.SolutionStatusProto:
"""Converts an optional Python SolutionStatus to a proto SolutionStatus."""
return solution_pb2.SOLUTION_STATUS_UNSPECIFIED if status is None else status.value
@dataclasses.dataclass
class PrimalSolution:
"""A solution to the optimization problem in a Model.
@@ -312,12 +330,13 @@ class Basis:
For two-sided LPs it may be different in some edge cases (e.g. incomplete
solves with primal simplex). For more details see
go/mathopt-basis-advanced#dualfeasibility. If you are providing a starting
basis via ModelSolveParameters.initial_basis, this value is ignored. It is
only relevant for the basis returned by Solution.basis. This is an
advanced status. For single-sided LPs it should be equal to the
feasibility status of the associated dual solution. For two-sided LPs it
may be different in some edge cases (e.g. incomplete solves with primal
simplex). For more details see go/mathopt-basis-advanced#dualfeasibility.
basis via ModelSolveParameters.initial_basis, this value is ignored and
can be None. It is only relevant for the basis returned by Solution.basis,
and it is never None when returned from solve(). This is an advanced
status. For single-sided LPs it should be equal to the feasibility status
of the associated dual solution. For two-sided LPs it may be different in
some edge cases (e.g. incomplete solves with primal simplex). For more
details see go/mathopt-basis-advanced#dualfeasibility.
"""
variable_status: Dict[model.Variable, BasisStatus] = dataclasses.field(
@@ -326,7 +345,7 @@ class Basis:
constraint_status: Dict[model.LinearConstraint, BasisStatus] = dataclasses.field(
default_factory=dict
)
basic_dual_feasibility: SolutionStatus = SolutionStatus.UNDETERMINED
basic_dual_feasibility: Optional[SolutionStatus] = None
def to_proto(self) -> solution_pb2.BasisProto:
"""Returns an equivalent proto for the basis."""
@@ -335,7 +354,9 @@ class Basis:
constraint_status=_to_sparse_basis_status_vector_proto(
self.constraint_status
),
basic_dual_feasibility=self.basic_dual_feasibility.value,
basic_dual_feasibility=optional_solution_status_to_proto(
self.basic_dual_feasibility
),
)
@@ -354,10 +375,9 @@ def parse_basis(proto: solution_pb2.BasisProto, mod: model.Model) -> Basis:
result.constraint_status[mod.get_linear_constraint(cid)] = BasisStatus(
status_proto
)
status_proto = proto.basic_dual_feasibility
if status_proto == solution_pb2.SOLUTION_STATUS_UNSPECIFIED:
raise ValueError("Basic dual feasibility status should not be UNSPECIFIED")
result.basic_dual_feasibility = SolutionStatus(status_proto)
result.basic_dual_feasibility = parse_optional_solution_status(
proto.basic_dual_feasibility
)
return result

View File

@@ -19,6 +19,18 @@ from ortools.math_opt.python import solution
from ortools.math_opt.python.testing import compare_proto
class SolutionStatusTest(absltest.TestCase):
def test_optional_status_round_trip(self):
for status in solution_pb2.SolutionStatusProto.values():
self.assertEqual(
status,
solution.optional_solution_status_to_proto(
solution.parse_optional_solution_status(status)
),
)
class ParsePrimalSolutionTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
def test_empty_primal_solution_proto_round_trip(self) -> None:
@@ -164,13 +176,11 @@ class BasisTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
empty_basis = solution.Basis()
empty_proto = empty_basis.to_proto()
expected_proto = solution_pb2.BasisProto()
expected_proto.basic_dual_feasibility = (
solution_pb2.SOLUTION_STATUS_UNDETERMINED
)
self.assert_protos_equiv(expected_proto, empty_proto)
round_trip_basis = solution.parse_basis(empty_proto, mod)
self.assertEmpty(round_trip_basis.constraint_status)
self.assertEmpty(round_trip_basis.variable_status)
self.assertIsNone(round_trip_basis.basic_dual_feasibility)
def test_basis_proto_round_trip(self) -> None:
mod = model.Model(name="test_model")
@@ -252,24 +262,9 @@ class BasisTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):
def test_basic_dual_feasibility_unspecified(self) -> None:
mod = model.Model(name="test_model")
mod.add_binary_variable(name="x")
mod.add_binary_variable(name="y")
mod.add_linear_constraint(lb=0.0, ub=1.0, name="c")
mod.add_linear_constraint(lb=0.0, ub=1.0, name="d")
basis_proto = solution_pb2.BasisProto()
basis_proto.constraint_status.ids[:] = [0, 1]
basis_proto.constraint_status.values[:] = [
solution_pb2.BASIS_STATUS_BASIC,
solution_pb2.BASIS_STATUS_AT_UPPER_BOUND,
]
basis_proto.variable_status.ids[:] = [0, 1]
basis_proto.variable_status.values[:] = [
solution_pb2.BASIS_STATUS_AT_UPPER_BOUND,
solution_pb2.BASIS_STATUS_BASIC,
]
basis_proto.basic_dual_feasibility = solution_pb2.SOLUTION_STATUS_UNSPECIFIED
with self.assertRaisesRegex(ValueError, "Basic dual feasibility.*UNSPECIFIED"):
solution.parse_basis(basis_proto, mod)
basis = solution.parse_basis(basis_proto, mod)
self.assertIsNone(basis.basic_dual_feasibility)
class ParseSolutionTest(compare_proto.MathOptProtoAssertions, absltest.TestCase):

View File

@@ -37,8 +37,9 @@ class SolverResources:
MOE:begin_intracomment_strip
The go/uoss server will use these parameters to do a bin-packing of all
requests. They are generally used as soft-limits though instead of
hard-limits and a solve may be able to consume more resources than requested.
requests. Parameter cpu is a soft-limit, the solve may still be able to use
more CPUs. The ram parameter is an hard-limit, an out-of-memory error will
occur if the solve attempts to use more memory.
MOE:end_intracomment_strip
@@ -58,9 +59,12 @@ class SolverResources:
better to consult each solver documentation to set this parameter. Note
that if the SolveParameters.threads is not set then this parameter should
also be left unset.
ram: The limit of RAM for the solve in bytes. Must be finite and >=1.0 (even
though it should in practice be much larger).
"""
cpu: Optional[float] = None
ram: Optional[float] = None
def to_proto(self) -> rpc_pb2.SolverResourcesProto:
return rpc_pb2.SolverResourcesProto(cpu=self.cpu)
return rpc_pb2.SolverResourcesProto(cpu=self.cpu, ram=self.ram)

View File

@@ -34,6 +34,12 @@ class SolverResourcesTest(compare_proto.MathOptProtoAssertions, absltest.TestCas
rpc_pb2.SolverResourcesProto(cpu=3.5),
)
def test_to_proto_with_ram(self):
self.assert_protos_equiv(
solver_resources.SolverResources(ram=50 * 1024 * 1024).to_proto(),
rpc_pb2.SolverResourcesProto(ram=50 * 1024 * 1024),
)
if __name__ == "__main__":
absltest.main()

View File

@@ -37,7 +37,6 @@ option java_multiple_files = true;
// When using SolveService.StreamSolve these hints are used to dimension the
// resources available during the execution of every action; thus it is
// recommended to set them.
//
message SolverResourcesProto {
// The number of solver threads that are expected to actually execute in
// parallel. Must be finite and >0.0.
@@ -61,6 +60,10 @@ message SolverResourcesProto {
// Note that if the SolveParametersProto.threads is not set then this
// parameter should also be left unset.
optional double cpu = 1;
// The limit of RAM for the solve in bytes. Must be finite and >=1.0 (even
// though it should in practice be much larger).
optional double ram = 2;
}
// Request for a unary remote solve in MathOpt.

View File

@@ -338,7 +338,7 @@ absl::StatusOr<SolveResultProto> CpSatSolver::Solve(
const ModelSolveParametersProto& model_parameters,
const MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, const Callback cb,
SolveInterrupter* const interrupter) {
const SolveInterrupter* const interrupter) {
const absl::Time start = absl::Now();
RETURN_IF_ERROR(CheckRegisteredCallbackEvents(
@@ -528,7 +528,7 @@ InvertedBounds CpSatSolver::ListInvertedBounds() const {
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
CpSatSolver::ComputeInfeasibleSubsystem(const SolveParametersProto&,
MessageCallback,
SolveInterrupter* const) {
const SolveInterrupter*) {
return absl::UnimplementedError(
"CPSAT does not provide a method to compute an infeasible subsystem");
}

View File

@@ -46,12 +46,12 @@ class CpSatSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
private:
CpSatSolver(MPModelProto cp_sat_model, std::vector<int64_t> variable_ids,

View File

@@ -765,7 +765,7 @@ absl::StatusOr<SolveResultProto> GlopSolver::Solve(
const ModelSolveParametersProto& model_parameters,
const MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, const Callback,
SolveInterrupter* const interrupter) {
const SolveInterrupter* const interrupter) {
RETURN_IF_ERROR(CheckRegisteredCallbackEvents(callback_registration,
/*supported_events=*/{}));
@@ -883,7 +883,8 @@ absl::StatusOr<bool> GlopSolver::Update(const ModelUpdateProto& model_update) {
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
GlopSolver::ComputeInfeasibleSubsystem(const SolveParametersProto&,
MessageCallback, SolveInterrupter*) {
MessageCallback,
const SolveInterrupter*) {
return absl::UnimplementedError(
"GLOP does not implement a method to compute an infeasible subsystem");
}

View File

@@ -53,12 +53,12 @@ class GlopSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
// Returns the merged parameters and a list of warnings from any parameter
// settings that are invalid for this solver.

View File

@@ -490,7 +490,7 @@ absl::Status SetLPParameters(const SolveParametersProto& parameters,
class MipCallbackData {
public:
explicit MipCallbackData(SolveInterrupter* const interrupter)
explicit MipCallbackData(const SolveInterrupter* const interrupter)
: interrupter_(interrupter) {}
void Callback(glp_tree* const tree) {
@@ -540,7 +540,7 @@ class MipCallbackData {
private:
// Optional interrupter.
SolveInterrupter* const interrupter_;
const SolveInterrupter* const interrupter_;
// Set to true if glp_ios_terminate() has been called due to the interrupter.
std::atomic<bool> interrupted_by_interrupter_ = false;
@@ -1059,7 +1059,7 @@ absl::StatusOr<SolveResultProto> GlpkSolver::Solve(
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration,
const Callback /*cb*/, SolveInterrupter* const interrupter) {
const Callback /*cb*/, const SolveInterrupter* const interrupter) {
RETURN_IF_ERROR(CheckCurrentThread());
const absl::Time start = absl::Now();
@@ -1804,9 +1804,9 @@ std::optional<SolveResultProto> GlpkSolver::EmptyIntegerBoundsResult() {
}
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
GlpkSolver::ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* const interrupter) {
GlpkSolver::ComputeInfeasibleSubsystem(
const SolveParametersProto& parameters, MessageCallback message_cb,
const SolveInterrupter* const interrupter) {
return absl::UnimplementedError(
"GLPK does not provide a method to compute an infeasible subsystem");
}

View File

@@ -55,12 +55,12 @@ class GlpkSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
private:
// The columns of the GPLK problem.

View File

@@ -1013,7 +1013,7 @@ absl::StatusOr<SolveResultProto> GScipSolver::Solve(
const ModelSolveParametersProto& model_parameters,
const MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* const interrupter) {
const SolveInterrupter* const interrupter) {
const absl::Time start = absl::Now();
GScip::Interrupter gscip_interrupter;
@@ -1352,7 +1352,7 @@ absl::StatusOr<bool> GScipSolver::Update(const ModelUpdateProto& model_update) {
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
GScipSolver::ComputeInfeasibleSubsystem(const SolveParametersProto&,
MessageCallback,
SolveInterrupter* const) {
const SolveInterrupter*) {
return absl::UnimplementedError(
"SCIP does not provide a method to compute an infeasible subsystem");
}

View File

@@ -58,12 +58,12 @@ class GScipSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
// Returns the merged parameters and a list of warnings for unsupported
// parameters.

View File

@@ -2893,7 +2893,7 @@ absl::StatusOr<SolveResultProto> GurobiSolver::Solve(
const ModelSolveParametersProto& model_parameters,
const MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, const Callback cb,
SolveInterrupter* const interrupter) {
const SolveInterrupter* const interrupter) {
const absl::Time start = absl::Now();
// Need to run GRBupdatemodel before:
@@ -3020,9 +3020,9 @@ absl::StatusOr<SolveResultProto> GurobiSolver::Solve(
// TODO(b/277339044): Remove code duplication with GurobiSolver::Solve().
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
GurobiSolver::ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* const interrupter) {
GurobiSolver::ComputeInfeasibleSubsystem(
const SolveParametersProto& parameters, MessageCallback message_cb,
const SolveInterrupter* const interrupter) {
const absl::Time start = absl::Now();
// Need to run GRBupdatemodel before:

View File

@@ -60,12 +60,12 @@ class GurobiSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
private:
struct GurobiCallbackData {
@@ -352,7 +352,7 @@ class GurobiSolver : public SolverInterface {
absl::StatusOr<std::unique_ptr<GurobiCallbackData>> RegisterCallback(
const CallbackRegistrationProto& registration, Callback cb,
MessageCallback message_cb, absl::Time start,
SolveInterrupter* interrupter);
SolveInterrupter* local_interrupter);
// Returns the ids of variables and linear constraints with inverted bounds.
absl::StatusOr<InvertedBounds> ListInvertedBounds() const;

View File

@@ -907,7 +907,7 @@ absl::StatusOr<SolveResultProto> HighsSolver::Solve(
const SolveParametersProto& parameters,
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb, const CallbackRegistrationProto&, Callback,
SolveInterrupter* const) {
const SolveInterrupter* const) {
const absl::Time start = absl::Now();
auto set_solve_time = [&start](SolveResultProto& result) -> absl::Status {
const absl::Duration solve_time = absl::Now() - start;
@@ -1008,7 +1008,7 @@ absl::StatusOr<bool> HighsSolver::Update(const ModelUpdateProto&) {
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
HighsSolver::ComputeInfeasibleSubsystem(const SolveParametersProto&,
MessageCallback,
SolveInterrupter* const) {
const SolveInterrupter*) {
return absl::UnimplementedError(
"HiGHS does not provide a method to compute an infeasible subsystem");
}

View File

@@ -48,12 +48,12 @@ class HighsSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
private:
struct SolutionClaims {

View File

@@ -333,7 +333,7 @@ absl::StatusOr<SolveResultProto> PdlpSolver::Solve(
const ModelSolveParametersProto& model_parameters,
const MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, const Callback,
SolveInterrupter* const interrupter) {
const SolveInterrupter* const interrupter) {
RETURN_IF_ERROR(CheckRegisteredCallbackEvents(callback_registration,
/*supported_events=*/{}));
@@ -376,7 +376,7 @@ absl::StatusOr<bool> PdlpSolver::Update(const ModelUpdateProto&) {
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
PdlpSolver::ComputeInfeasibleSubsystem(const SolveParametersProto&,
MessageCallback,
SolveInterrupter* const) {
const SolveInterrupter*) {
return absl::UnimplementedError(
"PDLP does not provide a method to compute an infeasible subsystem");
}

View File

@@ -43,12 +43,12 @@ class PdlpSolver : public SolverInterface {
const ModelSolveParametersProto& model_parameters,
MessageCallback message_cb,
const CallbackRegistrationProto& callback_registration, Callback cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
absl::StatusOr<bool> Update(const ModelUpdateProto& model_update) override;
absl::StatusOr<ComputeInfeasibleSubsystemResultProto>
ComputeInfeasibleSubsystem(const SolveParametersProto& parameters,
MessageCallback message_cb,
SolveInterrupter* interrupter) override;
const SolveInterrupter* interrupter) override;
// Returns the merged parameters and a list of warnings.
static absl::StatusOr<pdlp::PrimalDualHybridGradientParams> MergeParameters(

View File

@@ -278,7 +278,7 @@ absl::Status PrintSummary(const Model& model, const SolveResult& result,
absl::StatusOr<SolveResult> LocalOrRemoteSolve(
const Model& model, const SolverType solver_type,
const SolveParameters& params, const ModelSolveParameters& model_params,
MessageCallback msg_cb, SolveInterrupter* interrupter) {
MessageCallback msg_cb, const SolveInterrupter* const interrupter) {
if (absl::GetFlag(FLAGS_remote)) {
return absl::UnimplementedError("remote not yet supported.");
} else {

View File

@@ -52,7 +52,7 @@ void SolveInterrupter::Interrupt() {
}
SolveInterrupter::CallbackId SolveInterrupter::AddInterruptionCallback(
Callback callback) {
Callback callback) const {
const absl::MutexLock lock(&mutex_);
// We must make this call while holding the lock since we want to be sure that
@@ -73,13 +73,14 @@ SolveInterrupter::CallbackId SolveInterrupter::AddInterruptionCallback(
return id;
}
void SolveInterrupter::RemoveInterruptionCallback(CallbackId id) {
void SolveInterrupter::RemoveInterruptionCallback(CallbackId id) const {
const absl::MutexLock lock(&mutex_);
CHECK_EQ(callbacks_.erase(id), 1) << "unregistered callback id: " << id;
}
ScopedSolveInterrupterCallback::ScopedSolveInterrupterCallback(
SolveInterrupter* const interrupter, SolveInterrupter::Callback callback)
const SolveInterrupter* const interrupter,
SolveInterrupter::Callback callback)
: interrupter_(interrupter),
callback_id_(
interrupter != nullptr

View File

@@ -72,7 +72,11 @@ class SolveInterrupter {
// The callback function can't make calls to AddInterruptionCallback(),
// RemoveInterruptionCallback() and Interrupt(). This would result is a
// deadlock. Calling IsInterrupted() is fine though.
CallbackId AddInterruptionCallback(Callback callback);
//
// This method is `const` since it does not modify the state of the
// interrupter (the result of IsInterrupted()). This enables passing a
// const-ref to solvers, making sure they can't call Interrupt() by mistake.
CallbackId AddInterruptionCallback(Callback callback) const;
// Unregisters a callback previously registered. It fails (with a CHECK) if
// the callback was already unregistered or unkonwn. After this calls returns,
@@ -80,7 +84,7 @@ class SolveInterrupter {
//
// This function can't be called from a callback since this would result in a
// deadlock.
void RemoveInterruptionCallback(CallbackId id);
void RemoveInterruptionCallback(CallbackId id) const;
private:
// This atomic must never be reset to false!
@@ -88,21 +92,22 @@ class SolveInterrupter {
// The mutex_ should be held when setting it to true.
std::atomic<bool> interrupted_ = false;
absl::Mutex mutex_;
mutable absl::Mutex mutex_;
// The id to use for the next registered callback.
CallbackId next_callback_id_ ABSL_GUARDED_BY(mutex_) = {};
mutable CallbackId next_callback_id_ ABSL_GUARDED_BY(mutex_) = {};
// The list of callbacks. We use a linked_hash_map to make sure the order of
// calls to callback when the interrupter is triggered is stable.
gtl::linked_hash_map<CallbackId, Callback> callbacks_ ABSL_GUARDED_BY(mutex_);
mutable gtl::linked_hash_map<CallbackId, Callback> callbacks_
ABSL_GUARDED_BY(mutex_);
};
// Class implementing RAII for interruption callbacks.
//
// Usage:
//
// SolveInterrupter* const interrupter = ...;
// const SolveInterrupter* const interrupter = ...;
// {
// const ScopedSolveInterrupterCallback scoped_intr_cb(interrupter, [](){
// // Do something when/if interrupter is not nullptr and is triggered.
@@ -117,7 +122,7 @@ class ScopedSolveInterrupterCallback {
public:
// Adds a callback to the interrupter if it is not nullptr. Does nothing when
// interrupter is nullptr.
ScopedSolveInterrupterCallback(SolveInterrupter* interrupter,
ScopedSolveInterrupterCallback(const SolveInterrupter* interrupter,
SolveInterrupter::Callback callback);
ScopedSolveInterrupterCallback(const ScopedSolveInterrupterCallback&) =
@@ -134,11 +139,11 @@ class ScopedSolveInterrupterCallback {
void RemoveCallbackIfNecessary();
// Returns the optional interrupter.
SolveInterrupter* interrupter() const { return interrupter_; }
const SolveInterrupter* interrupter() const { return interrupter_; }
private:
// Optional interrupter.
SolveInterrupter* const interrupter_;
const SolveInterrupter* const interrupter_;
// Unset after the callback has been reset.
std::optional<SolveInterrupter::CallbackId> callback_id_;