math_opt: export from google3

This commit is contained in:
Corentin Le Molgat
2024-04-15 17:59:32 +02:00
parent 80e677c19b
commit 039192f29c
13 changed files with 250 additions and 60 deletions

View File

@@ -212,13 +212,19 @@ cc_library(
cc_library(
name = "map_filter",
srcs = ["map_filter.cc"],
hdrs = ["map_filter.h"],
deps = [
":key_types",
":linear_constraint",
":model",
":variable_and_expressions",
"//ortools/base:status_macros",
"//ortools/math_opt:sparse_containers_cc_proto",
"//ortools/math_opt/storage:model_storage",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
],
)

View File

@@ -0,0 +1,66 @@
// Copyright 2010-2024 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/cpp/map_filter.h"
#include <cstdint>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "ortools/base/status_builder.h"
#include "ortools/math_opt/cpp/linear_constraint.h"
#include "ortools/math_opt/cpp/model.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
namespace operations_research::math_opt {
absl::StatusOr<MapFilter<Variable>> VariableFilterFromProto(
const Model& model, const SparseVectorFilterProto& proto) {
MapFilter<Variable> result = {.skip_zero_values = proto.skip_zero_values()};
if (proto.filter_by_ids()) {
absl::flat_hash_set<Variable> filtered;
for (const int64_t id : proto.filtered_ids()) {
if (!model.has_variable(id)) {
return util::InvalidArgumentErrorBuilder()
<< "cannot create MapFilter<Variable> from proto, variable id: "
<< id << " not in model";
}
filtered.insert(model.variable(id));
}
result.filtered_keys = std::move(filtered);
}
return result;
}
absl::StatusOr<MapFilter<LinearConstraint>> LinearConstraintFilterFromProto(
const Model& model, const SparseVectorFilterProto& proto) {
MapFilter<LinearConstraint> result = {.skip_zero_values =
proto.skip_zero_values()};
if (proto.filter_by_ids()) {
absl::flat_hash_set<LinearConstraint> filtered;
for (const int64_t id : proto.filtered_ids()) {
if (!model.has_linear_constraint(id)) {
return util::InvalidArgumentErrorBuilder()
<< "cannot create MapFilter<LinearConstraint> from proto, "
"linear constraint id: "
<< id << " not in model";
}
filtered.insert(model.linear_constraint(id));
}
result.filtered_keys = std::move(filtered);
}
return result;
}
} // namespace operations_research::math_opt

View File

@@ -22,8 +22,12 @@
#include <optional>
#include "absl/algorithm/container.h"
#include "absl/status/statusor.h"
#include "ortools/base/status_macros.h"
#include "ortools/math_opt/cpp/key_types.h"
#include "ortools/math_opt/cpp/linear_constraint.h"
#include "ortools/math_opt/cpp/model.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/sparse_containers.pb.h"
#include "ortools/math_opt/storage/model_storage.h"
@@ -105,6 +109,20 @@ struct MapFilter {
SparseVectorFilterProto Proto() const;
};
// Returns the MapFilter<Variable> equivalent to `proto`.
//
// Requires that (or returns a status error):
// * proto.filtered_ids has elements that are variables in `model`.
absl::StatusOr<MapFilter<Variable>> VariableFilterFromProto(
const Model& model, const SparseVectorFilterProto& proto);
// Returns the MapFilter<LinearConstraint> equivalent to `proto`.
//
// Requires that (or returns a status error):
// * proto.filtered_ids has elements that are linear constraints in `model`.
absl::StatusOr<MapFilter<LinearConstraint>> LinearConstraintFilterFromProto(
const Model& model, const SparseVectorFilterProto& proto);
// Returns a filter that skips all key-value pairs.
//
// This is typically used to disable the dual data in SolveResult when these are

View File

@@ -145,6 +145,21 @@ ObjectiveParametersProto ModelSolveParameters::ObjectiveParameters::Proto()
return params;
}
ModelSolveParameters::ObjectiveParameters
ModelSolveParameters::ObjectiveParameters::FromProto(
const ObjectiveParametersProto& proto) {
ObjectiveParameters result;
if (proto.has_objective_degradation_absolute_tolerance()) {
result.objective_degradation_absolute_tolerance =
proto.objective_degradation_absolute_tolerance();
}
if (proto.has_objective_degradation_relative_tolerance()) {
result.objective_degradation_relative_tolerance =
proto.objective_degradation_relative_tolerance();
}
return result;
}
// TODO: b/315974557 - Return an error if a RepeatedField is too long.
ModelSolveParametersProto ModelSolveParameters::Proto() const {
ModelSolveParametersProto ret;
@@ -152,43 +167,10 @@ ModelSolveParametersProto ModelSolveParameters::Proto() const {
*ret.mutable_dual_values_filter() = dual_values_filter.Proto();
*ret.mutable_reduced_costs_filter() = reduced_costs_filter.Proto();
// TODO(b/183616124): consolidate code. Probably best to add an
// export_to_proto to IdMap
if (initial_basis) {
RepeatedField<int64_t>& constraint_status_ids =
*ret.mutable_initial_basis()
->mutable_constraint_status()
->mutable_ids();
RepeatedField<int>& constraint_status_values =
*ret.mutable_initial_basis()
->mutable_constraint_status()
->mutable_values();
constraint_status_ids.Reserve(
static_cast<int>(initial_basis->constraint_status.size()));
constraint_status_values.Reserve(
static_cast<int>(initial_basis->constraint_status.size()));
for (const LinearConstraint& key :
SortedKeys(initial_basis->constraint_status)) {
constraint_status_ids.Add(key.id());
constraint_status_values.Add(
EnumToProto(initial_basis->constraint_status.at(key)));
}
RepeatedField<int64_t>& variable_status_ids =
*ret.mutable_initial_basis()->mutable_variable_status()->mutable_ids();
RepeatedField<int>& variable_status_values =
*ret.mutable_initial_basis()
->mutable_variable_status()
->mutable_values();
variable_status_ids.Reserve(
static_cast<int>(initial_basis->variable_status.size()));
variable_status_values.Reserve(
static_cast<int>(initial_basis->variable_status.size()));
for (const Variable& key : SortedKeys(initial_basis->variable_status)) {
variable_status_ids.Add(key.id());
variable_status_values.Add(
EnumToProto(initial_basis->variable_status.at(key)));
}
if (initial_basis.has_value()) {
*ret.mutable_initial_basis() = initial_basis->Proto();
}
for (const SolutionHint& solution_hint : solution_hints) {
*ret.add_solution_hints() = solution_hint.Proto();
}
@@ -226,5 +208,64 @@ ModelSolveParametersProto ModelSolveParameters::Proto() const {
return ret;
}
absl::StatusOr<ModelSolveParameters> ModelSolveParameters::FromProto(
const Model& model, const ModelSolveParametersProto& proto) {
ModelSolveParameters result;
OR_ASSIGN_OR_RETURN3(
result.variable_values_filter,
VariableFilterFromProto(model, proto.variable_values_filter()),
_ << "invalid variable_values_filter");
OR_ASSIGN_OR_RETURN3(
result.dual_values_filter,
LinearConstraintFilterFromProto(model, proto.dual_values_filter()),
_ << "invalid dual_values_filter");
OR_ASSIGN_OR_RETURN3(
result.reduced_costs_filter,
VariableFilterFromProto(model, proto.reduced_costs_filter()),
_ << "invalid reduced_costs_filter");
if (proto.has_initial_basis()) {
OR_ASSIGN_OR_RETURN3(
result.initial_basis,
Basis::FromProto(model.storage(), proto.initial_basis()),
_ << "invalid initial_basis");
}
for (int i = 0; i < proto.solution_hints_size(); ++i) {
OR_ASSIGN_OR_RETURN3(
SolutionHint hint,
SolutionHint::FromProto(model, proto.solution_hints(i)),
_ << "invalid solution_hints[" << i << "]");
result.solution_hints.push_back(std::move(hint));
}
OR_ASSIGN_OR_RETURN3(
result.branching_priorities,
VariableValuesFromProto(model.storage(), proto.branching_priorities()),
_ << "invalid branching_priorities");
if (proto.has_primary_objective_parameters()) {
result.objective_parameters.try_emplace(
Objective::Primary(model.storage()),
ObjectiveParameters::FromProto(proto.primary_objective_parameters()));
}
for (const auto& [id, aux_obj_params_proto] :
proto.auxiliary_objective_parameters()) {
if (!model.has_auxiliary_objective(id)) {
return util::InvalidArgumentErrorBuilder()
<< "invalid auxiliary_objective_parameters with id: " << id
<< ", objective not in the model";
}
result.objective_parameters.try_emplace(
Objective::Auxiliary(model.storage(), AuxiliaryObjectiveId{id}),
ObjectiveParameters::FromProto(aux_obj_params_proto));
}
for (int64_t lin_con : proto.lazy_linear_constraint_ids()) {
if (!model.has_linear_constraint(lin_con)) {
return util::InvalidArgumentErrorBuilder()
<< "invalid lazy_linear_constraint with id: " << lin_con
<< ", constraint not in the model";
}
result.lazy_linear_constraints.insert(model.linear_constraint(lin_con));
}
return result;
}
} // namespace math_opt
} // namespace operations_research

View File

@@ -182,6 +182,8 @@ struct ModelSolveParameters {
// Returns the proto equivalent of this object.
ObjectiveParametersProto Proto() const;
static ObjectiveParameters FromProto(const ObjectiveParametersProto& proto);
};
// Parameters for individual objectives in a multi-objective model.
ObjectiveMap<ObjectiveParameters> objective_parameters;
@@ -204,6 +206,11 @@ struct ModelSolveParameters {
// The caller should use CheckModelStorage() as this function does not check
// internal consistency of the referenced variables and constraints.
ModelSolveParametersProto Proto() const;
// Returns the ModelSolveParameters corresponding to this proto and the given
// model.
static absl::StatusOr<ModelSolveParameters> FromProto(
const Model& model, const ModelSolveParametersProto& proto);
};
////////////////////////////////////////////////////////////////////////////////

View File

@@ -187,13 +187,8 @@ absl::StatusOr<Basis> Basis::FromProto(const ModelStorage* model,
basis.variable_status,
VariableBasisFromProto(model, basis_proto.variable_status()),
_ << "invalid variable_status");
const std::optional<SolutionStatus> basic_dual_feasibility =
basis.basic_dual_feasibility =
EnumFromProto(basis_proto.basic_dual_feasibility());
if (!basic_dual_feasibility.has_value()) {
return absl::InvalidArgumentError(
"basic_dual_feasibility for a basis must be specified");
}
basis.basic_dual_feasibility = *basic_dual_feasibility;
return basis;
}

View File

@@ -211,7 +211,6 @@ struct Basis {
// Returns an error if:
// * VariableBasisFromProto(basis_proto.variable_status) fails.
// * LinearConstraintBasisFromProto(basis_proto.constraint_status) fails.
// * basis_proto.basic_dual_feasibility is unspecified.
static absl::StatusOr<Basis> FromProto(const ModelStorage* model,
const BasisProto& basis_proto);
@@ -238,8 +237,9 @@ struct Basis {
//
// If you are providing a starting basis via
// `ModelSolveParameters.initial_basis`, this value is ignored. It is only
// relevant for the basis returned by `Solution.basis`.
SolutionStatus basic_dual_feasibility = SolutionStatus::kUndetermined;
// relevant for the basis returned by `Solution.basis`, and it is is always
// populated in a Basis returned by a call to Solve().
std::optional<SolutionStatus> basic_dual_feasibility;
};
// What is included in a solution depends on the kind of problem and solver.

View File

@@ -133,6 +133,13 @@ absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
return MakeView(vars_proto).as_map<Variable>(model);
}
absl::StatusOr<VariableMap<int32_t>> VariableValuesFromProto(
const ModelStorage* model, const SparseInt32VectorProto& vars_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto));
RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids()));
return MakeView(vars_proto).as_map<Variable>(model);
}
SparseDoubleVectorProto VariableValuesToProto(
const VariableMap<double>& variable_values) {
return MapToProto(variable_values);

View File

@@ -44,13 +44,23 @@ namespace operations_research::math_opt {
// Requires that (or returns a status error):
// * vars_proto.ids and vars_proto.values have equal size.
// * vars_proto.ids is sorted.
// * vars_proto.ids has elements in [0, max(int64_t)).
// * vars_proto.ids has elements that are variables in `model`.
// * vars_proto.ids has elements that are variables in `model` (this implies
// that each id is in [0, max(int64_t))).
//
// Note that the values of vars_proto.values are not checked (it may have NaNs).
absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
const ModelStorage* model, const SparseDoubleVectorProto& vars_proto);
// Returns the VariableMap<int32_t> equivalent to `vars_proto`.
//
// Requires that (or returns a status error):
// * vars_proto.ids and vars_proto.values have equal size.
// * vars_proto.ids is sorted.
// * vars_proto.ids has elements that are variables in `model` (this implies
// that each id is in [0, max(int64_t))).
absl::StatusOr<VariableMap<int32_t>> VariableValuesFromProto(
const ModelStorage* model, const SparseInt32VectorProto& vars_proto);
// Returns the proto equivalent of variable_values.
SparseDoubleVectorProto VariableValuesToProto(
const VariableMap<double>& variable_values);
@@ -79,8 +89,8 @@ google::protobuf::Map<int64_t, double> AuxiliaryObjectiveValuesToProto(
// Requires that (or returns a status error):
// * lin_cons_proto.ids and lin_cons_proto.values have equal size.
// * lin_cons_proto.ids is sorted.
// * lin_cons_proto.ids has elements in [0, max(int64_t)).
// * lin_cons_proto.ids has elements that are linear constraints in `model`.
// * lin_cons_proto.ids has elements that are linear constraints in `model`
// (this implies that each id is in [0, max(int64_t))).
//
// Note that the values of lin_cons_proto.values are not checked (it may have
// NaNs).
@@ -96,8 +106,8 @@ SparseDoubleVectorProto LinearConstraintValuesToProto(
// Requires that (or returns a status error):
// * basis_proto.ids and basis_proto.values have equal size.
// * basis_proto.ids is sorted.
// * basis_proto.ids has elements in [0, max(int64_t)).
// * basis_proto.ids has elements that are variables in `model`.
// * basis_proto.ids has elements that are variables in `model` (this implies
// that each id is in [0, max(int64_t))).
// * basis_proto.values does not contain UNSPECIFIED and has valid enum values.
absl::StatusOr<VariableMap<BasisStatus>> VariableBasisFromProto(
const ModelStorage* model, const SparseBasisStatusVector& basis_proto);
@@ -111,8 +121,8 @@ SparseBasisStatusVector VariableBasisToProto(
// Requires that (or returns a status error):
// * basis_proto.ids and basis_proto.values have equal size.
// * basis_proto.ids is sorted.
// * basis_proto.ids has elements in [0, max(int64_t)).
// * basis_proto.ids has elements that are linear constraints in `model`.
// * basis_proto.ids has elements that are linear constraints in `model` (this
// implies that each id is in [0, max(int64_t))).
// * basis_proto.values does not contain UNSPECIFIED and has valid enum values.
absl::StatusOr<LinearConstraintMap<BasisStatus>> LinearConstraintBasisFromProto(
const ModelStorage* model, const SparseBasisStatusVector& basis_proto);

View File

@@ -92,10 +92,26 @@ message SolveRequest {
// Response for a unary remote solve in MathOpt.
message SolveResponse {
// Either `result` or `status` must be set. This is equivalent to C++
// StatusOr<SolveResult>.
oneof status_or {
// Description of the output of solving the model in the request.
SolveResultProto result = 1;
// The absl::Status returned by the solver. It should never be OK when set.
StatusProto status = 3;
}
// If SolveParametersProto.enable_output has been used, this will contain log
// messages for solvers that support message callbacks.
repeated string messages = 2;
}
// The streamed version of absl::Status.
message StatusProto {
// The status code, one of the absl::StatusCode.
int32 code = 1;
// The status message.
string message = 2;
}

View File

@@ -17,7 +17,7 @@ package(default_visibility = ["//visibility:private"])
cc_binary(
name = "mathopt_solve",
srcs = ["mathopt_solve_main.cc"],
srcs = ["mathopt_solve.cc"],
deps = [
":file_format_flags",
"//ortools/base",
@@ -34,7 +34,9 @@ cc_binary(
"//ortools/math_opt/solvers:gscip_solver",
"//ortools/math_opt/solvers:highs_solver",
"//ortools/math_opt/solvers:pdlp_solver",
"//ortools/util:sigint",
"//ortools/util:status_macros",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
@@ -46,7 +48,7 @@ cc_binary(
cc_binary(
name = "mathopt_convert",
srcs = ["mathopt_convert_main.cc"],
srcs = ["mathopt_convert.cc"],
deps = [
":file_format_flags",
"//ortools/base",

View File

@@ -32,6 +32,7 @@
#include <utility>
#include <vector>
#include "absl/base/no_destructor.h"
#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
@@ -53,6 +54,7 @@
#include "ortools/math_opt/labs/solution_feasibility_checker.h"
#include "ortools/math_opt/parameters.pb.h"
#include "ortools/math_opt/tools/file_format_flags.h"
#include "ortools/util/sigint.h"
#include "ortools/util/status_macros.h"
namespace {
@@ -105,6 +107,9 @@ ABSL_FLAG(bool, solver_logs, false,
"use a message callback to print the solver convergence logs");
ABSL_FLAG(absl::Duration, time_limit, absl::InfiniteDuration(),
"the time limit to use for the solve");
ABSL_FLAG(bool, sigint_interrupt, true,
"interrupts the solve on the first SIGINT; kill the process on the "
"third one");
ABSL_FLAG(bool, names, true,
"use the names in the input models; ignoring names is useful when "
@@ -273,18 +278,35 @@ absl::Status PrintSummary(const Model& model, const SolveResult& result,
absl::StatusOr<SolveResult> LocalOrRemoteSolve(
const Model& model, const SolverType solver_type,
const SolveParameters& params, const ModelSolveParameters& model_params,
MessageCallback msg_cb) {
MessageCallback msg_cb, SolveInterrupter* interrupter) {
if (absl::GetFlag(FLAGS_remote)) {
return absl::UnimplementedError("remote not yet supported.");
} else {
return Solve(model, solver_type,
{.parameters = params,
.model_parameters = model_params,
.message_callback = std::move(msg_cb)});
.message_callback = std::move(msg_cb),
.interrupter = interrupter});
}
}
absl::Status RunSolver() {
// We use absl::NoDestructor here so that the SIGINT handler is kept until the
// very end of the process, making sure a late Ctrl-C on the very end of the
// solve don't kill the process.
static absl::NoDestructor<SigintHandler> sigint_handler;
static const absl::NoDestructor<std::unique_ptr<SolveInterrupter>>
interrupter([&]() -> std::unique_ptr<SolveInterrupter> {
if (!absl::GetFlag(FLAGS_sigint_interrupt)) {
return nullptr;
}
auto interrupter =
std::make_unique<operations_research::SolveInterrupter>();
sigint_handler->Register(
[interrupter = interrupter.get()]() { interrupter->Interrupt(); });
return interrupter;
}());
if (absl::GetFlag(FLAGS_remote) &&
absl::GetFlag(FLAGS_time_limit) == absl::InfiniteDuration()) {
return absl::InvalidArgumentError(
@@ -318,9 +340,9 @@ absl::Status RunSolver() {
}
OR_ASSIGN_OR_RETURN3(
const SolveResult result,
LocalOrRemoteSolve(*model_and_hint.model,
absl::GetFlag(FLAGS_solver_type), solve_params,
model_params, std::move(message_cb)),
LocalOrRemoteSolve(
*model_and_hint.model, absl::GetFlag(FLAGS_solver_type), solve_params,
model_params, std::move(message_cb), interrupter->get()),
_ << "the solver failed");
const FeasibilityCheckerOptions feasibility_checker_options = {