Files
ortools-clone/ortools/math_opt/solver.cc
2021-04-11 12:05:38 +02:00

168 lines
6.0 KiB
C++

// Copyright 2010-2021 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/solver.h"
#include <stdint.h>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "ortools/base/integral_types.h"
#include "ortools/base/logging.h"
#include "ortools/base/status_macros.h"
#include "ortools/math_opt/callback.pb.h"
#include "ortools/math_opt/model.pb.h"
#include "ortools/math_opt/model_parameters.pb.h"
#include "ortools/math_opt/model_summary.h"
#include "ortools/math_opt/model_update.pb.h"
#include "ortools/math_opt/parameters.pb.h"
#include "ortools/math_opt/result.pb.h"
#include "ortools/math_opt/solver_interface.h"
#include "ortools/math_opt/validators/callback_validator.h"
#include "ortools/math_opt/validators/model_parameters_validator.h"
#include "ortools/math_opt/validators/model_validator.h"
#include "ortools/math_opt/validators/solution_validator.h"
#include "ortools/math_opt/validators/solver_parameters_validator.h"
namespace operations_research {
namespace math_opt {
namespace {
template <typename IdNameContainer>
void UpdateIdNameMap(const absl::Span<const int64_t> deleted_ids,
const IdNameContainer& container, IdNameBiMap& bimap) {
for (const int64_t deleted_id : deleted_ids) {
bimap.Erase(deleted_id);
}
for (int i = 0; i < container.ids_size(); ++i) {
std::string name;
if (!container.names().empty()) {
name = container.names(i);
}
bimap.Insert(container.ids(i), std::move(name));
}
}
ModelSummary MakeSummary(const ModelProto& model) {
ModelSummary summary;
UpdateIdNameMap<VariablesProto>({}, model.variables(), summary.variables);
UpdateIdNameMap<LinearConstraintsProto>({}, model.linear_constraints(),
summary.linear_constraints);
return summary;
}
void UpdateSummaryFromModelUpdate(const ModelUpdateProto& model_update,
ModelSummary& summary) {
UpdateIdNameMap<VariablesProto>(model_update.deleted_variable_ids(),
model_update.new_variables(),
summary.variables);
UpdateIdNameMap<LinearConstraintsProto>(
model_update.deleted_linear_constraint_ids(),
model_update.new_linear_constraints(), summary.linear_constraints);
}
// Returns an InternalError with the input status message if the input status is
// not OK.
absl::Status ToInternalError(const absl::Status original) {
if (original.ok()) {
return original;
}
return absl::InternalError(original.message());
}
} // namespace
Solver::Solver(std::unique_ptr<SolverInterface> underlying_solver,
ModelSummary model_summary)
: underlying_solver_(std::move(underlying_solver)),
model_summary_(std::move(model_summary)) {
CHECK(underlying_solver_ != nullptr);
}
absl::StatusOr<std::unique_ptr<Solver>> Solver::New(
const SolverType solver_type, const ModelProto& model,
const SolverInitializerProto& initializer) {
RETURN_IF_ERROR(ValidateModel(model));
ASSIGN_OR_RETURN(
auto underlying_solver,
AllSolversRegistry::Instance()->Create(solver_type, model, initializer));
auto result = absl::WrapUnique(
new Solver(std::move(underlying_solver), MakeSummary(model)));
return result;
}
absl::StatusOr<SolveResultProto> Solver::Solve(
const SolveParametersProto& parameters,
const ModelSolveParametersProto& model_parameters,
const CallbackRegistrationProto& callback_registration,
const Callback user_cb) {
// TODO(b/168037341): we should validate the result maths. Since the result
// can be filtered, this should be included in the solver_interface
// implementations.
RETURN_IF_ERROR(ValidateSolverParameters(parameters)) << "invalid parameters";
RETURN_IF_ERROR(
ValidateModelSolveParameters(model_parameters, model_summary_))
<< "invalid model_parameters";
SolverInterface::Callback cb = nullptr;
if (user_cb != nullptr) {
RETURN_IF_ERROR(
ValidateCallbackRegistration(callback_registration, model_summary_));
cb = [&](const CallbackDataProto& callback_data)
-> absl::StatusOr<CallbackResultProto> {
RETURN_IF_ERROR(ValidateCallbackDataProto(
callback_data, callback_registration, model_summary_));
auto callback_result = user_cb(callback_data);
RETURN_IF_ERROR(
ValidateCallbackResultProto(callback_result, callback_data.event(),
callback_registration, model_summary_));
return callback_result;
};
}
ASSIGN_OR_RETURN(const SolveResultProto result,
underlying_solver_->Solve(parameters, model_parameters,
callback_registration, cb));
// We consider errors in `result` to be internal errors, but
// `ValidateResult()` will return an InvalidArgumentError. So here we convert
// the error.
RETURN_IF_ERROR(ToInternalError(
ValidateResult(result, model_parameters, model_summary_)));
return result;
}
absl::StatusOr<bool> Solver::Update(const ModelUpdateProto& model_update) {
RETURN_IF_ERROR(ValidateModelUpdateAndSummary(model_update, model_summary_));
if (!underlying_solver_->CanUpdate(model_update)) {
return false;
}
UpdateSummaryFromModelUpdate(model_update, model_summary_);
RETURN_IF_ERROR(underlying_solver_->Update(model_update));
return true;
}
} // namespace math_opt
} // namespace operations_research