diff --git a/ortools/math_opt/core/BUILD.bazel b/ortools/math_opt/core/BUILD.bazel index 4ce3175cda..566a1b158a 100644 --- a/ortools/math_opt/core/BUILD.bazel +++ b/ortools/math_opt/core/BUILD.bazel @@ -39,7 +39,12 @@ cc_library( "//ortools/base", "//ortools/base:linked_hash_map", "//ortools/base:map_util", + "//ortools/base:status_builder", + "//ortools/base:status_macros", + "//ortools/math_opt:model_cc_proto", + "//ortools/math_opt:model_update_cc_proto", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], ) diff --git a/ortools/math_opt/core/model_storage.cc b/ortools/math_opt/core/model_storage.cc index 5409fb274a..604a9763d5 100644 --- a/ortools/math_opt/core/model_storage.cc +++ b/ortools/math_opt/core/model_storage.cc @@ -151,7 +151,7 @@ absl::StatusOr> ModelStorage::FromModelProto( // models. Thus a model built by ModelStorage can contain duplicated // names. And since we use FromModelProto() to implement Clone(), we must make // sure duplicated names don't fail. - RETURN_IF_ERROR(ValidateModel(model_proto, /*check_names=*/false)); + RETURN_IF_ERROR(ValidateModel(model_proto, /*check_names=*/false).status()); auto storage = std::make_unique(model_proto.name()); @@ -810,21 +810,23 @@ absl::Status ModelStorage::ApplyUpdateProto( const ModelUpdateProto& update_proto) { // Check the update first. { - ModelSummary summary; - // We have to use sorted keys since IdNameBiMap expect Insert() to be called - // in sorted order. + // Do not check for duplicate names, as with FromModelProto(); + ModelSummary summary(/*check_names=*/false); + // IdNameBiMap requires Insert() calls to be in sorted id order. for (const auto id : SortedVariables()) { - summary.variables.Insert(id.value(), variable_name(id)); + RETURN_IF_ERROR(summary.variables.Insert(id.value(), variable_name(id))) + << "invalid variable id in model"; } - summary.variables.SetNextFreeId(next_variable_id_.value()); + RETURN_IF_ERROR(summary.variables.SetNextFreeId(next_variable_id_.value())); for (const auto id : SortedLinearConstraints()) { - summary.linear_constraints.Insert(id.value(), linear_constraint_name(id)); + RETURN_IF_ERROR(summary.linear_constraints.Insert( + id.value(), linear_constraint_name(id))) + << "invalid linear constraint id in model"; } - summary.linear_constraints.SetNextFreeId( - next_linear_constraint_id_.value()); - // We don't check the names for the same reason as in FromModelProto(). - RETURN_IF_ERROR(ValidateModelUpdateAndSummary(update_proto, summary, - /*check_names=*/false)); + RETURN_IF_ERROR(summary.linear_constraints.SetNextFreeId( + next_linear_constraint_id_.value())); + RETURN_IF_ERROR(ValidateModelUpdate(update_proto, summary)) + << "update not valid"; } // Remove deleted variables and constraints. diff --git a/ortools/math_opt/core/model_summary.cc b/ortools/math_opt/core/model_summary.cc index 6352828dfe..5e254614b2 100644 --- a/ortools/math_opt/core/model_summary.cc +++ b/ortools/math_opt/core/model_summary.cc @@ -22,13 +22,117 @@ namespace operations_research { namespace math_opt { +namespace { + +// TODO(b/232526223): this is an exact copy of +// CheckIdsRangeAndStrictlyIncreasing from ids_validator.h, find a way to share +// the code. +absl::Status CheckIdsRangeAndStrictlyIncreasing2( + absl::Span ids) { + int64_t previous{-1}; + for (int i = 0; i < ids.size(); previous = ids[i], ++i) { + if (ids[i] < 0 || ids[i] == std::numeric_limits::max()) { + return util::InvalidArgumentErrorBuilder() + << "Expected ids to be nonnegative and not max(int64_t) but at " + "index " + << i << " found id: " << ids[i]; + } + if (ids[i] <= previous) { + return util::InvalidArgumentErrorBuilder() + << "Expected ids to be strictly increasing, but at index " << i + << " found id: " << ids[i] << " and at previous index " << i - 1 + << " found id: " << ids[i - 1]; + } + } + return absl::OkStatus(); +} + +} // namespace IdNameBiMap::IdNameBiMap( - std::initializer_list> ids) { + std::initializer_list> ids) + : IdNameBiMap(/*check_names=*/true) { for (const auto& pair : ids) { - Insert(pair.first, std::string(pair.second)); + CHECK_OK(Insert(pair.first, std::string(pair.second))); } } +IdNameBiMap::IdNameBiMap(const IdNameBiMap& other) { *this = other; } + +IdNameBiMap& IdNameBiMap::operator=(const IdNameBiMap& other) { + if (&other == this) { + return *this; + } + next_free_id_ = other.next_free_id_; + id_to_name_ = other.id_to_name_; + if (!other.nonempty_name_to_id_.has_value()) { + nonempty_name_to_id_ = std::nullopt; + } else { + nonempty_name_to_id_.emplace(); + for (const auto& [id, name] : id_to_name_) { + if (!name.empty()) { + const auto [it, success] = + nonempty_name_to_id_->insert({absl::string_view(name), id}); + CHECK(success); // CHECK is OK, other cannot have duplicate names and + // non nullopt nonempty_name_to_id_. + } + } + } + + return *this; +} + +absl::Status IdNameBiMap::BulkUpdate( + absl::Span deleted_ids, absl::Span new_ids, + const absl::Span names) { + RETURN_IF_ERROR(CheckIdsRangeAndStrictlyIncreasing2(deleted_ids)) + << "invalid deleted ids"; + RETURN_IF_ERROR(CheckIdsRangeAndStrictlyIncreasing2(new_ids)) + << "invalid new ids"; + if (!names.empty() && names.size() != new_ids.size()) { + return util::InvalidArgumentErrorBuilder() + << "names had size " << names.size() + << " but should either be empty of have size matching new_ids which " + "has size " + << new_ids.size(); + } + for (const int64_t id : deleted_ids) { + RETURN_IF_ERROR(Erase(id)); + } + for (int i = 0; i < new_ids.size(); ++i) { + RETURN_IF_ERROR( + Insert(new_ids[i], names.empty() ? std::string{} : *names[i])); + } + return absl::OkStatus(); +} + +ModelSummary::ModelSummary(const bool check_names) + : variables(check_names), linear_constraints(check_names) {} + +absl::StatusOr ModelSummary::Create(const ModelProto& model, + const bool check_names) { + ModelSummary summary(check_names); + RETURN_IF_ERROR(summary.variables.BulkUpdate({}, model.variables().ids(), + model.variables().names())) + << "Model.variables are invalid"; + RETURN_IF_ERROR(summary.linear_constraints.BulkUpdate( + {}, model.linear_constraints().ids(), model.linear_constraints().names())) + << "Model.linear_constraints are invalid"; + return summary; +} + +absl::Status ModelSummary::Update(const ModelUpdateProto& model_update) { + RETURN_IF_ERROR(variables.BulkUpdate(model_update.deleted_variable_ids(), + model_update.new_variables().ids(), + model_update.new_variables().names())) + << "invalid variables"; + RETURN_IF_ERROR(linear_constraints.BulkUpdate( + model_update.deleted_linear_constraint_ids(), + model_update.new_linear_constraints().ids(), + model_update.new_linear_constraints().names())) + << "invalid linear constraints"; + return absl::OkStatus(); +} + } // namespace math_opt } // namespace operations_research diff --git a/ortools/math_opt/core/model_summary.h b/ortools/math_opt/core/model_summary.h index d51c4522fe..3d9b95a776 100644 --- a/ortools/math_opt/core/model_summary.h +++ b/ortools/math_opt/core/model_summary.h @@ -21,10 +21,15 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "ortools/base/linked_hash_map.h" #include "ortools/base/logging.h" #include "ortools/base/map_util.h" +#include "ortools/base/status_builder.h" +#include "ortools/base/status_macros.h" +#include "ortools/math_opt/model.pb.h" +#include "ortools/math_opt/model_update.pb.h" namespace operations_research { namespace math_opt { @@ -37,25 +42,46 @@ namespace math_opt { // * Ids are non-negative. // * Ids are not equal to std::numeric_limits::max() // * Ids removed are never reused. -// * Names must be either empty or unique. +// * Names must be either empty or unique when built with check_names=true. class IdNameBiMap { public: - IdNameBiMap() = default; + // If check_names=false, the names need not be unique and the reverse mapping + // of name to id is not available. + explicit IdNameBiMap(bool check_names = true) + : nonempty_name_to_id_( + check_names ? std::make_optional< + absl::flat_hash_map>() + : std::nullopt) {} + + // Needs a custom copy constructor/assign because absl::string_view to + // internal data is held as a member. No custom move is needed. + IdNameBiMap(const IdNameBiMap& other); + IdNameBiMap& operator=(const IdNameBiMap& other); + IdNameBiMap(IdNameBiMap&& other) = default; + IdNameBiMap& operator=(IdNameBiMap&& other) = default; // This constructor CHECKs that the input ids are sorted in increasing // order. This constructor is expected to be used only for unit tests of // validation code. IdNameBiMap(std::initializer_list> ids); - // Inserts the provided id and associate the provided name to it. CHECKs that - // id >= next_free_id() and that when the name is nonempty it is not already - // present. As a side effect it updates next_free_id to id + 1. - inline void Insert(int64_t id, std::string name); + // Inserts the provided id and associate the provided name to it. + // + // An error is returned if: + // * id is negative + // * id is not at least next_free_id() + // * id is max(int64_t) + // * name is a duplicated and check_names was true at construction. + // + // As a side effect it updates next_free_id to id + 1. + inline absl::Status Insert(int64_t id, std::string name); - // Removes the given id. CHECKs that it is present. - inline void Erase(int64_t id); + // Removes the given id, or returns an error if id is not present. + inline absl::Status Erase(int64_t id); inline bool HasId(int64_t id) const; + + // Will always return false if name is empty or if check_names was false. inline bool HasName(absl::string_view name) const; inline bool Empty() const; inline int Size() const; @@ -64,23 +90,26 @@ class IdNameBiMap { // non-negative). inline int64_t next_free_id() const; - // Updates next_free_id(). CHECKs that the provided id is greater than any - // exiting id and non negative. - // - // In practice this should only be used to increase the next_free_id() value - // in cases where a ModelSummary is built with an existing model but we know - // some ids of removed elements have already been used. - inline void SetNextFreeId(int64_t new_next_free_id); + // Updates next_free_id(). Succeeds when the provided id is greater than every + // exiting id and is non-negative. + inline absl::Status SetNextFreeId(int64_t new_next_free_id); // Iteration order is in increasing id order. const gtl::linked_hash_map& id_to_name() const { return id_to_name_; } - const absl::flat_hash_map& nonempty_name_to_id() - const { + + // Is std::nullopt if check_names was false at construction. + const std::optional>& + nonempty_name_to_id() const { return nonempty_name_to_id_; } + // Warning: this may be mutated (partially updated) if an error is returned. + absl::Status BulkUpdate(absl::Span deleted_ids, + absl::Span new_ids, + absl::Span names); + private: // Next unused id. int64_t next_free_id_ = 0; @@ -88,10 +117,18 @@ class IdNameBiMap { // Pointer stability for name strings and iterating in insertion order are // both needed (so we do not use flat_hash_map). gtl::linked_hash_map id_to_name_; - absl::flat_hash_map nonempty_name_to_id_; + std::optional> + nonempty_name_to_id_; }; +// TODO(b/232619901): In the guide for how to add new constraints, include how +// this class must updated. struct ModelSummary { + explicit ModelSummary(bool check_names = true); + static absl::StatusOr Create(const ModelProto& model, + bool check_names = true); + absl::Status Update(const ModelUpdateProto& model_update); + IdNameBiMap variables; IdNameBiMap linear_constraints; }; @@ -100,35 +137,61 @@ struct ModelSummary { // Inline function implementations //////////////////////////////////////////////////////////////////////////////// -void IdNameBiMap::Insert(const int64_t id, std::string name) { - CHECK_GE(id, next_free_id_); - CHECK_LT(id, std::numeric_limits::max()); +absl::Status IdNameBiMap::Insert(const int64_t id, std::string name) { + if (id < next_free_id_) { + return util::InvalidArgumentErrorBuilder() + << "expected id=" << id + << " to be at least next_free_id_=" << next_free_id_ + << " (ids should be nonnegative and inserted in strictly increasing " + "order)"; + } + if (id == std::numeric_limits::max()) { + return absl::InvalidArgumentError("id of max(int64_t) is not allowed"); + } next_free_id_ = id + 1; const auto [it, success] = id_to_name_.emplace(id, std::move(name)); - CHECK(success) << "id: " << id; + CHECK(success); // CHECK is okay, we have the invariant that next_free_id_ is + // larger than everything in the map. const absl::string_view name_view(it->second); - if (!name_view.empty()) { - gtl::InsertOrDie(&nonempty_name_to_id_, name_view, id); + if (nonempty_name_to_id_.has_value() && !name_view.empty()) { + const auto [it, success] = nonempty_name_to_id_->insert({name_view, id}); + if (!success) { + return util::InvalidArgumentErrorBuilder() + << "duplicate name inserted: " << name_view; + } } + return absl::OkStatus(); } -void IdNameBiMap::Erase(const int64_t id) { +absl::Status IdNameBiMap::Erase(const int64_t id) { const auto it = id_to_name_.find(id); - CHECK(it != id_to_name_.end()) << id; + if (it == id_to_name_.end()) { + return util::InvalidArgumentErrorBuilder() + << "cannot delete missing id " << id; + } const absl::string_view name_view(it->second); - if (!name_view.empty()) { - CHECK_EQ(1, nonempty_name_to_id_.erase(name_view)) + if (nonempty_name_to_id_.has_value() && !name_view.empty()) { + // CHECK OK, name_view being in nonempty_name_to_id_ when the above is met + // is a class invariant. + CHECK_EQ(nonempty_name_to_id_->erase(name_view), 1) << "name: " << name_view << " id: " << id; } id_to_name_.erase(it); + return absl::OkStatus(); } + bool IdNameBiMap::HasId(const int64_t id) const { return id_to_name_.contains(id); } bool IdNameBiMap::HasName(const absl::string_view name) const { - CHECK(!name.empty()); - return nonempty_name_to_id_.contains(name); + if (name.empty()) { + return false; + } + if (!nonempty_name_to_id_.has_value()) { + return false; + } + return nonempty_name_to_id_->contains(name); } bool IdNameBiMap::Empty() const { return id_to_name_.empty(); } @@ -137,14 +200,23 @@ int IdNameBiMap::Size() const { return id_to_name_.size(); } int64_t IdNameBiMap::next_free_id() const { return next_free_id_; } -void IdNameBiMap::SetNextFreeId(const int64_t new_next_free_id) { +absl::Status IdNameBiMap::SetNextFreeId(const int64_t new_next_free_id) { if (!Empty()) { const int64_t largest_id = id_to_name_.back().first; - CHECK_GT(new_next_free_id, largest_id); + if (new_next_free_id <= largest_id) { + return util::InvalidArgumentErrorBuilder() + << "new_next_free_id=" << new_next_free_id + << " must be greater than largest_id=" << largest_id; + } } else { - CHECK_GE(new_next_free_id, 0); + if (new_next_free_id < 0) { + return util::InvalidArgumentErrorBuilder() + << "new_next_free_id=" << new_next_free_id + << " must be nonnegative"; + } } next_free_id_ = new_next_free_id; + return absl::OkStatus(); } } // namespace math_opt diff --git a/ortools/math_opt/core/solver.cc b/ortools/math_opt/core/solver.cc index 6e9bc14dda..d0dc33901e 100644 --- a/ortools/math_opt/core/solver.cc +++ b/ortools/math_opt/core/solver.cc @@ -51,39 +51,6 @@ namespace math_opt { namespace { -template -void UpdateIdNameMap(const absl::Span deleted_ids, - const IdNameContainer& container, IdNameBiMap& bimap) { - for (const int64_t deleted_id : deleted_ids) { - bimap.Erase(deleted_id); - } - for (int i = 0; i < container.ids_size(); ++i) { - std::string name; - if (!container.names().empty()) { - name = container.names(i); - } - bimap.Insert(container.ids(i), std::move(name)); - } -} - -ModelSummary MakeSummary(const ModelProto& model) { - ModelSummary summary; - UpdateIdNameMap({}, model.variables(), summary.variables); - UpdateIdNameMap({}, model.linear_constraints(), - summary.linear_constraints); - return summary; -} - -void UpdateSummaryFromModelUpdate(const ModelUpdateProto& model_update, - ModelSummary& summary) { - UpdateIdNameMap(model_update.deleted_variable_ids(), - model_update.new_variables(), - summary.variables); - UpdateIdNameMap( - model_update.deleted_linear_constraint_ids(), - model_update.new_linear_constraints(), summary.linear_constraints); -} - // Returns an InternalError with the input status message if the input status is // not OK. absl::Status ToInternalError(const absl::Status original) { @@ -188,12 +155,12 @@ absl::StatusOr> Solver::New( const SolverTypeProto solver_type, const ModelProto& model, const InitArgs& arguments) { RETURN_IF_ERROR(internal::ValidateInitArgs(arguments, solver_type)); - RETURN_IF_ERROR(ValidateModel(model)); + ASSIGN_OR_RETURN(ModelSummary summary, ValidateModel(model)); ASSIGN_OR_RETURN( auto underlying_solver, AllSolversRegistry::Instance()->Create(solver_type, model, arguments)); auto result = absl::WrapUnique( - new Solver(std::move(underlying_solver), MakeSummary(model))); + new Solver(std::move(underlying_solver), std::move(summary))); return result; } @@ -259,13 +226,18 @@ absl::StatusOr Solver::Update(const ModelUpdateProto& model_update) { // We will reset it in code paths where no error occur. fatal_failure_occurred_ = true; + // TODO(b/232264333): we are modifying model_summary_ but when CanUpdate + // returns false we are not in a good state. With a different design we can + // avoid this copy, which can be non-negligible. + ModelSummary backup = model_summary_; + RETURN_IF_ERROR(ValidateModelUpdate(model_update, model_summary_)); - RETURN_IF_ERROR(ValidateModelUpdateAndSummary(model_update, model_summary_)); if (!underlying_solver_->CanUpdate(model_update)) { + model_summary_ = std::move(backup); fatal_failure_occurred_ = false; return false; } - UpdateSummaryFromModelUpdate(model_update, model_summary_); + RETURN_IF_ERROR(underlying_solver_->Update(model_update)); fatal_failure_occurred_ = false; diff --git a/ortools/math_opt/cpp/BUILD.bazel b/ortools/math_opt/cpp/BUILD.bazel index 0c14241f4e..cef3cda279 100644 --- a/ortools/math_opt/cpp/BUILD.bazel +++ b/ortools/math_opt/cpp/BUILD.bazel @@ -12,6 +12,42 @@ cc_library( ], ) +cc_library( + name = "basis_status", + srcs = ["basis_status.cc"], + hdrs = ["basis_status.h"], + deps = [ + ":enums", + "//ortools/math_opt:solution_cc_proto", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "sparse_containers", + srcs = ["sparse_containers.cc"], + hdrs = ["sparse_containers.h"], + deps = [ + ":basis_status", + ":linear_constraint", + ":variable_and_expressions", + "//ortools/base", + "//ortools/base:status_macros", + "//ortools/math_opt:solution_cc_proto", + "//ortools/math_opt:sparse_containers_cc_proto", + "//ortools/math_opt/core:model_storage", + "//ortools/math_opt/core:sparse_vector_view", + "//ortools/math_opt/validators:ids_validator", + "//ortools/math_opt/validators:sparse_vector_validator", + "//ortools/util:status_macros", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "model", srcs = ["model.cc"], @@ -84,8 +120,10 @@ cc_library( srcs = ["solution.cc"], hdrs = ["solution.h"], deps = [ + ":basis_status", ":enums", ":linear_constraint", + ":sparse_containers", ":variable_and_expressions", "//ortools/base", "//ortools/base:intops", @@ -112,11 +150,13 @@ cc_library( "//ortools/base", "//ortools/base:protoutil", "//ortools/base:status_macros", + "//ortools/util:status_macros", "//ortools/math_opt:result_cc_proto", "//ortools/math_opt:solution_cc_proto", "//ortools/math_opt/core:model_storage", "//ortools/port:proto_utils", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -143,6 +183,7 @@ cc_library( ":enums", ":key_types", ":map_filter", + ":sparse_containers", ":variable_and_expressions", "//ortools/base", "//ortools/base:intops", diff --git a/ortools/math_opt/cpp/basis_status.cc b/ortools/math_opt/cpp/basis_status.cc new file mode 100644 index 0000000000..76df25b0b1 --- /dev/null +++ b/ortools/math_opt/cpp/basis_status.cc @@ -0,0 +1,50 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/cpp/basis_status.h" + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/math_opt/cpp/enums.h" + +namespace operations_research::math_opt { + +std::optional Enum::ToOptString( + BasisStatus value) { + switch (value) { + case BasisStatus::kFree: + return "free"; + case BasisStatus::kAtLowerBound: + return "at_lower_bound"; + case BasisStatus::kAtUpperBound: + return "at_upper_bound"; + case BasisStatus::kFixedValue: + return "fixed_value"; + case BasisStatus::kBasic: + return "basic"; + } + return std::nullopt; +} + +absl::Span Enum::AllValues() { + static constexpr BasisStatus kBasisStatusValues[] = { + BasisStatus::kFree, BasisStatus::kAtLowerBound, + BasisStatus::kAtUpperBound, BasisStatus::kFixedValue, + BasisStatus::kBasic, + }; + return absl::MakeConstSpan(kBasisStatusValues); +} + +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/cpp/basis_status.h b/ortools/math_opt/cpp/basis_status.h new file mode 100644 index 0000000000..5c0dc6a53a --- /dev/null +++ b/ortools/math_opt/cpp/basis_status.h @@ -0,0 +1,46 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_CPP_BASIS_STATUS_H_ +#define OR_TOOLS_MATH_OPT_CPP_BASIS_STATUS_H_ + +#include + +#include "ortools/math_opt/cpp/enums.h" // IWYU pragma: export +#include "ortools/math_opt/solution.pb.h" + +namespace operations_research::math_opt { + +// Status of a variable/constraint in a LP basis. +enum class BasisStatus : int8_t { + // The variable/constraint is free (it has no finite bounds). + kFree = BASIS_STATUS_FREE, + + // The variable/constraint is at its lower bound (which must be finite). + kAtLowerBound = BASIS_STATUS_AT_LOWER_BOUND, + + // The variable/constraint is at its upper bound (which must be finite). + kAtUpperBound = BASIS_STATUS_AT_UPPER_BOUND, + + // The variable/constraint has identical finite lower and upper bounds. + kFixedValue = BASIS_STATUS_FIXED_VALUE, + + // The variable/constraint is basic. + kBasic = BASIS_STATUS_BASIC, +}; + +MATH_OPT_DEFINE_ENUM(BasisStatus, BASIS_STATUS_UNSPECIFIED); + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_CPP_BASIS_STATUS_H_ diff --git a/ortools/math_opt/cpp/callback.cc b/ortools/math_opt/cpp/callback.cc index fbaf58f221..95f80b3374 100644 --- a/ortools/math_opt/cpp/callback.cc +++ b/ortools/math_opt/cpp/callback.cc @@ -31,6 +31,7 @@ #include "ortools/math_opt/core/model_storage.h" #include "ortools/math_opt/core/sparse_vector_view.h" #include "ortools/math_opt/cpp/map_filter.h" +#include "ortools/math_opt/cpp/sparse_containers.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/sparse_containers.pb.h" @@ -38,14 +39,6 @@ namespace operations_research { namespace math_opt { namespace { -std::vector> SortedVariableValues( - const VariableMap& var_map) { - std::vector> result(var_map.raw_map().begin(), - var_map.raw_map().end()); - std::sort(result.begin(), result.end()); - return result; -} - // Container must be an iterable on some type T where // const ModelStorage* T::storage() const // is defined. @@ -111,8 +104,8 @@ CallbackData::CallbackData(const ModelStorage* storage, mip_stats(proto.mip_stats()) { CHECK(EnumFromProto(proto.event()).has_value()); if (proto.has_primal_solution_vector()) { - solution = VariableMap( - storage, MakeView(proto.primal_solution_vector()).as_map()); + solution = VariableValuesFromProto(storage, proto.primal_solution_vector()) + .value(); } auto maybe_time = util_time::DecodeGoogleApiProto(proto.runtime()); CHECK_OK(maybe_time.status()); @@ -153,11 +146,7 @@ CallbackResultProto CallbackResult::Proto() const { CallbackResultProto result; result.set_terminate(terminate); for (const VariableMap& solution : suggested_solutions) { - SparseDoubleVectorProto* solution_vector = result.add_suggested_solutions(); - for (const auto& [typed_id, value] : SortedVariableValues(solution)) { - solution_vector->add_ids(typed_id.value()); - solution_vector->add_values(value); - } + *result.add_suggested_solutions() = VariableValuesToProto(solution); } for (const GeneratedLinearConstraint& constraint : new_constraints) { CallbackResultProto::GeneratedLinearConstraint* constraint_proto = @@ -167,11 +156,8 @@ CallbackResultProto CallbackResult::Proto() const { constraint.linear_constraint.lower_bound_minus_offset()); constraint_proto->set_upper_bound( constraint.linear_constraint.upper_bound_minus_offset()); - for (const auto& [typed_id, value] : SortedVariableValues( - constraint.linear_constraint.expression.terms())) { - constraint_proto->mutable_linear_expression()->add_ids(typed_id.value()); - constraint_proto->mutable_linear_expression()->add_values(value); - } + *constraint_proto->mutable_linear_expression() = + VariableValuesToProto(constraint.linear_constraint.expression.terms()); } return result; } diff --git a/ortools/math_opt/cpp/id_map.h b/ortools/math_opt/cpp/id_map.h index 9205293c12..3cddc33cc5 100644 --- a/ortools/math_opt/cpp/id_map.h +++ b/ortools/math_opt/cpp/id_map.h @@ -158,6 +158,9 @@ class IdMap { inline void insert(InputIt first, InputIt last); inline void insert(std::initializer_list ilist); + template + inline std::pair insert_or_assign(const K& k, M&& v); + inline std::pair emplace(const K& k, V v); template inline std::pair try_emplace(const K& k, Args&&... args); @@ -349,7 +352,7 @@ IdMap::const_iterator::const_iterator( template IdMap::IdMap(const ModelStorage* storage, StorageType values) - : storage_(storage), map_(std::move(values)) { + : storage_(values.empty() ? nullptr : storage), map_(std::move(values)) { if (!map_.empty()) { CHECK(storage_ != nullptr); } @@ -421,6 +424,16 @@ void IdMap::insert(std::initializer_list ilist) { insert(ilist.begin(), ilist.end()); } +template +template +std::pair::iterator, bool> IdMap::insert_or_assign( + const K& k, M&& v) { + CheckOrSetModel(k); + auto initial_ret = map_.insert_or_assign(k.typed_id(), std::forward(v)); + return std::make_pair(iterator(this, std::move(initial_ret.first)), + initial_ret.second); +} + template std::pair::iterator, bool> IdMap::emplace(const K& k, V v) { diff --git a/ortools/math_opt/cpp/id_set.h b/ortools/math_opt/cpp/id_set.h index 5480cf180c..16cdf16ea7 100644 --- a/ortools/math_opt/cpp/id_set.h +++ b/ortools/math_opt/cpp/id_set.h @@ -217,7 +217,7 @@ IdSet::const_iterator::const_iterator( template IdSet::IdSet(const ModelStorage* storage, StorageType values) - : storage_(storage), set_(std::move(values)) { + : storage_(values.empty() ? nullptr : storage), set_(std::move(values)) { if (!set_.empty()) { CHECK(storage_ != nullptr); } diff --git a/ortools/math_opt/cpp/solution.cc b/ortools/math_opt/cpp/solution.cc index 7bdf43bfcf..b5e348b12c 100644 --- a/ortools/math_opt/cpp/solution.cc +++ b/ortools/math_opt/cpp/solution.cc @@ -20,46 +20,21 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/base/logging.h" -#include "ortools/base/strong_int.h" +#include "ortools/base/status_builder.h" +#include "ortools/base/status_macros.h" #include "ortools/math_opt/core/model_storage.h" #include "ortools/math_opt/core/sparse_vector_view.h" #include "ortools/math_opt/cpp/linear_constraint.h" +#include "ortools/math_opt/cpp/sparse_containers.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/solution.pb.h" #include "ortools/math_opt/sparse_containers.pb.h" +#include "ortools/math_opt/validators/ids_validator.h" +#include "ortools/math_opt/validators/sparse_vector_validator.h" +#include "ortools/util/status_macros.h" namespace operations_research { namespace math_opt { -namespace { - -template -IdMap ValuesFrom(const ModelStorage* const model, - const SparseDoubleVectorProto& vars_proto) { - return IdMap( - model, MakeView(vars_proto).as_map()); -} - -template -IdMap BasisValues( - const ModelStorage* const model, - const SparseBasisStatusVector& basis_proto) { - absl::flat_hash_map id_map; - for (const auto& [id, basis_status_proto] : MakeView(basis_proto)) { - // CHECK fails on BASIS_STATUS_UNSPECIFIED (the validation code should have - // tested that). - // We need to cast because the C++ proto API stores repeated enums as ints. - // - // On top of that iOS 11 does not support .value() on optionals so we must - // use operator*. - const std::optional basis_status = - EnumFromProto(static_cast(basis_status_proto)); - CHECK(basis_status.has_value()); - id_map[static_cast(id)] = *basis_status; - } - return IdMap(model, std::move(id_map)); -} - -} // namespace std::optional Enum::ToOptString( SolutionStatus value) { @@ -83,110 +58,169 @@ absl::Span Enum::AllValues() { return absl::MakeConstSpan(kSolutionStatusValues); } -std::optional Enum::ToOptString( - BasisStatus value) { - switch (value) { - case BasisStatus::kFree: - return "free"; - case BasisStatus::kAtLowerBound: - return "at_lower_bound"; - case BasisStatus::kAtUpperBound: - return "at_upper_bound"; - case BasisStatus::kFixedValue: - return "fixed_value"; - case BasisStatus::kBasic: - return "basic"; - } - return std::nullopt; -} - -absl::Span Enum::AllValues() { - static constexpr BasisStatus kBasisStatusValues[] = { - BasisStatus::kFree, BasisStatus::kAtLowerBound, - BasisStatus::kAtUpperBound, BasisStatus::kFixedValue, - BasisStatus::kBasic, - }; - return absl::MakeConstSpan(kBasisStatusValues); -} - -PrimalSolution PrimalSolution::FromProto( +absl::StatusOr PrimalSolution::FromProto( const ModelStorage* model, const PrimalSolutionProto& primal_solution_proto) { PrimalSolution primal_solution; - primal_solution.variable_values = - ValuesFrom(model, primal_solution_proto.variable_values()); + OR_ASSIGN_OR_RETURN3( + primal_solution.variable_values, + VariableValuesFromProto(model, primal_solution_proto.variable_values()), + _ << "invalid variable_values"); primal_solution.objective_value = primal_solution_proto.objective_value(); - // TODO(b/209014770): consider adding a function to simplify this pattern. const std::optional feasibility_status = EnumFromProto(primal_solution_proto.feasibility_status()); - CHECK(feasibility_status.has_value()); + if (!feasibility_status.has_value()) { + return absl::InvalidArgumentError("feasibility_status must be specified"); + } primal_solution.feasibility_status = *feasibility_status; return primal_solution; } -PrimalRay PrimalRay::FromProto(const ModelStorage* model, - const PrimalRayProto& primal_ray_proto) { - return {.variable_values = - ValuesFrom(model, primal_ray_proto.variable_values())}; +PrimalSolutionProto PrimalSolution::Proto() const { + PrimalSolutionProto result; + *result.mutable_variable_values() = VariableValuesToProto(variable_values); + result.set_objective_value(objective_value); + result.set_feasibility_status(EnumToProto(feasibility_status)); + return result; } -DualSolution DualSolution::FromProto( +absl::StatusOr PrimalRay::FromProto( + const ModelStorage* model, const PrimalRayProto& primal_ray_proto) { + PrimalRay result; + OR_ASSIGN_OR_RETURN3( + result.variable_values, + VariableValuesFromProto(model, primal_ray_proto.variable_values()), + _ << "invalid variable_values"); + return result; +} + +PrimalRayProto PrimalRay::Proto() const { + PrimalRayProto result; + *result.mutable_variable_values() = VariableValuesToProto(variable_values); + return result; +} + +absl::StatusOr DualSolution::FromProto( const ModelStorage* model, const DualSolutionProto& dual_solution_proto) { DualSolution dual_solution; - dual_solution.dual_values = - ValuesFrom(model, dual_solution_proto.dual_values()); - dual_solution.reduced_costs = - ValuesFrom(model, dual_solution_proto.reduced_costs()); + OR_ASSIGN_OR_RETURN3( + dual_solution.dual_values, + LinearConstraintValuesFromProto(model, dual_solution_proto.dual_values()), + _ << "invalid dual_values"); + OR_ASSIGN_OR_RETURN3( + dual_solution.reduced_costs, + VariableValuesFromProto(model, dual_solution_proto.reduced_costs()), + _ << "invalid reduced_costs"); if (dual_solution_proto.has_objective_value()) { dual_solution.objective_value = dual_solution_proto.objective_value(); } - // TODO(b/209014770): consider adding a function to simplify this pattern. const std::optional feasibility_status = EnumFromProto(dual_solution_proto.feasibility_status()); - CHECK(feasibility_status.has_value()); + if (!feasibility_status.has_value()) { + return absl::InvalidArgumentError("feasibility_status must be specified"); + } dual_solution.feasibility_status = *feasibility_status; return dual_solution; } -DualRay DualRay::FromProto(const ModelStorage* model, - const DualRayProto& dual_ray_proto) { - return {.dual_values = - ValuesFrom(model, dual_ray_proto.dual_values()), - .reduced_costs = - ValuesFrom(model, dual_ray_proto.reduced_costs())}; +DualSolutionProto DualSolution::Proto() const { + DualSolutionProto result; + *result.mutable_dual_values() = LinearConstraintValuesToProto(dual_values); + *result.mutable_reduced_costs() = VariableValuesToProto(reduced_costs); + if (objective_value.has_value()) { + result.set_objective_value(*objective_value); + } + result.set_feasibility_status(EnumToProto(feasibility_status)); + return result; } -Basis Basis::FromProto(const ModelStorage* model, - const BasisProto& basis_proto) { +absl::StatusOr DualRay::FromProto(const ModelStorage* model, + const DualRayProto& dual_ray_proto) { + DualRay result; + OR_ASSIGN_OR_RETURN3( + result.dual_values, + LinearConstraintValuesFromProto(model, dual_ray_proto.dual_values()), + _ << "invalid dual_values"); + OR_ASSIGN_OR_RETURN3( + result.reduced_costs, + VariableValuesFromProto(model, dual_ray_proto.reduced_costs()), + _ << "invalid reduced_costs"); + return result; +} + +DualRayProto DualRay::Proto() const { + DualRayProto result; + *result.mutable_dual_values() = LinearConstraintValuesToProto(dual_values); + *result.mutable_reduced_costs() = VariableValuesToProto(reduced_costs); + return result; +} + +absl::StatusOr Basis::FromProto(const ModelStorage* model, + const BasisProto& basis_proto) { Basis basis; - basis.constraint_status = - BasisValues(model, basis_proto.constraint_status()); - basis.variable_status = - BasisValues(model, basis_proto.variable_status()); - // TODO(b/209014770): consider adding a function to simplify this pattern. + OR_ASSIGN_OR_RETURN3( + basis.constraint_status, + LinearConstraintBasisFromProto(model, basis_proto.constraint_status()), + _ << "invalid constraint_status"); + OR_ASSIGN_OR_RETURN3( + basis.variable_status, + VariableBasisFromProto(model, basis_proto.variable_status()), + _ << "invalid variable_status"); const std::optional basic_dual_feasibility = EnumFromProto(basis_proto.basic_dual_feasibility()); - CHECK(basic_dual_feasibility.has_value()); + if (!basic_dual_feasibility.has_value()) { + return absl::InvalidArgumentError( + "basic_dual_feasibility for a basis must be specified"); + } basis.basic_dual_feasibility = *basic_dual_feasibility; return basis; } -Solution Solution::FromProto(const ModelStorage* model, - const SolutionProto& solution_proto) { +BasisProto Basis::Proto() const { + BasisProto result; + *result.mutable_constraint_status() = + LinearConstraintBasisToProto(constraint_status); + *result.mutable_variable_status() = VariableBasisToProto(variable_status); + result.set_basic_dual_feasibility(EnumToProto(basic_dual_feasibility)); + return result; +} + +absl::StatusOr Solution::FromProto( + const ModelStorage* model, const SolutionProto& solution_proto) { Solution solution; if (solution_proto.has_primal_solution()) { - solution.primal_solution = - PrimalSolution::FromProto(model, solution_proto.primal_solution()); + OR_ASSIGN_OR_RETURN3( + solution.primal_solution, + PrimalSolution::FromProto(model, solution_proto.primal_solution()), + _ << "invalid primal_solution"); } if (solution_proto.has_dual_solution()) { - solution.dual_solution = - DualSolution::FromProto(model, solution_proto.dual_solution()); + OR_ASSIGN_OR_RETURN3( + solution.dual_solution, + DualSolution::FromProto(model, solution_proto.dual_solution()), + _ << "invalid dual_solution"); } if (solution_proto.has_basis()) { - solution.basis = Basis::FromProto(model, solution_proto.basis()); + OR_ASSIGN_OR_RETURN3(solution.basis, + Basis::FromProto(model, solution_proto.basis()), + _ << "invalid basis"); } return solution; } +SolutionProto Solution::Proto() const { + SolutionProto result; + if (primal_solution.has_value()) { + *result.mutable_primal_solution() = primal_solution->Proto(); + } + if (dual_solution.has_value()) { + *result.mutable_dual_solution() = dual_solution->Proto(); + } + if (basis.has_value()) { + *result.mutable_basis() = basis->Proto(); + } + return result; +} + } // namespace math_opt } // namespace operations_research diff --git a/ortools/math_opt/cpp/solution.h b/ortools/math_opt/cpp/solution.h index 5c6a12db3e..6554f7ceea 100644 --- a/ortools/math_opt/cpp/solution.h +++ b/ortools/math_opt/cpp/solution.h @@ -20,6 +20,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "ortools/math_opt/core/model_storage.h" +#include "ortools/math_opt/cpp/basis_status.h" #include "ortools/math_opt/cpp/enums.h" // IWYU pragma: export #include "ortools/math_opt/cpp/linear_constraint.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" @@ -43,26 +44,6 @@ enum class SolutionStatus { MATH_OPT_DEFINE_ENUM(SolutionStatus, SOLUTION_STATUS_UNSPECIFIED); -// Status of a variable/constraint in a LP basis. -enum class BasisStatus : int8_t { - // The variable/constraint is free (it has no finite bounds). - kFree = BASIS_STATUS_FREE, - - // The variable/constraint is at its lower bound (which must be finite). - kAtLowerBound = BASIS_STATUS_AT_LOWER_BOUND, - - // The variable/constraint is at its upper bound (which must be finite). - kAtUpperBound = BASIS_STATUS_AT_UPPER_BOUND, - - // The variable/constraint has identical finite lower and upper bounds. - kFixedValue = BASIS_STATUS_FIXED_VALUE, - - // The variable/constraint is basic. - kBasic = BASIS_STATUS_BASIC, -}; - -MATH_OPT_DEFINE_ENUM(BasisStatus, BASIS_STATUS_UNSPECIFIED); - // A solution to an optimization problem. // // E.g. consider a simple linear program: @@ -76,10 +57,18 @@ MATH_OPT_DEFINE_ENUM(BasisStatus, BASIS_STATUS_UNSPECIFIED); // For the general case of a MathOpt optimization model, see // go/mathopt-solutions for details. struct PrimalSolution { - static PrimalSolution FromProto( + // Returns the PrimalSolution equivalent of primal_solution_proto. + // + // Returns an error when: + // * VariableValuesFromProto(primal_solution_proto.variable_values) fails. + // * the feasibility_status is not specified. + static absl::StatusOr FromProto( const ModelStorage* model, const PrimalSolutionProto& primal_solution_proto); + // Returns the proto equivalent of this. + PrimalSolutionProto Proto() const; + VariableMap variable_values; double objective_value = 0.0; @@ -107,8 +96,15 @@ struct PrimalSolution { // For the general case of a MathOpt optimization model, see // go/mathopt-solutions for details. struct PrimalRay { - static PrimalRay FromProto(const ModelStorage* model, - const PrimalRayProto& primal_ray_proto); + // Returns the PrimalRay equivalent of primal_ray_proto. + // + // Returns an error when + // VariableValuesFromProto(primal_ray_proto.variable_values) fails. + static absl::StatusOr FromProto( + const ModelStorage* model, const PrimalRayProto& primal_ray_proto); + + // Returns the proto equivalent of this. + PrimalRayProto Proto() const; VariableMap variable_values; }; @@ -128,8 +124,17 @@ struct PrimalRay { // For the general case, see go/mathopt-solutions and go/mathopt-dual (and // note that the dual objective depends on r in the general case). struct DualSolution { - static DualSolution FromProto(const ModelStorage* model, - const DualSolutionProto& dual_solution_proto); + // Returns the DualSolution equivalent of dual_solution_proto. + // + // Returns an error when any of: + // * VariableValuesFromProto(dual_solution_proto.reduced_costs) fails. + // * LinearConstraintValuesFromProto(dual_solution_proto.dual_values) fails. + // * dual_solution_proto.feasibility_status is not specified. + static absl::StatusOr FromProto( + const ModelStorage* model, const DualSolutionProto& dual_solution_proto); + + // Returns the proto equivalent of this. + DualSolutionProto Proto() const; LinearConstraintMap dual_values; VariableMap reduced_costs; @@ -159,8 +164,16 @@ struct DualSolution { // For the general case, see go/mathopt-solutions and go/mathopt-dual (and // note that the dual objective depends on r in the general case). struct DualRay { - static DualRay FromProto(const ModelStorage* model, - const DualRayProto& dual_ray_proto); + // Returns the DualRay equivalent of dual_ray_proto. + // + // Returns an error when either of: + // * VariableValuesFromProto(dual_ray_proto.reduced_costs) fails. + // * LinearConstraintValuesFromProto(dual_ray_proto.dual_values) fails. + static absl::StatusOr FromProto(const ModelStorage* model, + const DualRayProto& dual_ray_proto); + + // Returns the proto equivalent of this. + DualRayProto Proto() const; LinearConstraintMap dual_values; VariableMap reduced_costs; @@ -193,12 +206,17 @@ struct DualRay { // See go/mathopt-basis for treatment of the general case and an explanation // of how a dual solution is determined for a basis. struct Basis { - // Returns a Basis built from the input indexed_basis, CHECKing that no - // values is BASIS_STATUS_UNSPECIFIED. No check is done on other values so - // out of bounds values e.g. BasisStatusProto_MAX+1 won't raise an - // assertion. See SpaseBasisStatusVectorIsValid(). - static Basis FromProto(const ModelStorage* model, - const BasisProto& basis_proto); + // Returns the equivalent Basis object for basis_proto. + // + // Returns an error if: + // * VariableBasisFromProto(basis_proto.variable_status) fails. + // * LinearConstraintBasisFromProto(basis_proto.constraint_status) fails. + // * basis_proto.basic_dual_feasibility is unspecified. + static absl::StatusOr FromProto(const ModelStorage* model, + const BasisProto& basis_proto); + + // Returns the proto equivalent of this. + BasisProto Proto() const; LinearConstraintMap constraint_status; VariableMap variable_status; @@ -218,8 +236,16 @@ struct Basis { // 3. Other continuous solvers often return a primal and dual solution // solution that are connected in a solver-dependent form. struct Solution { - static Solution FromProto(const ModelStorage* model, - const SolutionProto& solution_proto); + // Returns the Solution equivalent of solution_proto. + // + // Returns an error if FromProto() fails on any field that is not std::nullopt + // (see the static FromProto() functions for each field type for details). + static absl::StatusOr FromProto( + const ModelStorage* model, const SolutionProto& solution_proto); + + // Returns the proto equivalent of this. + SolutionProto Proto() const; + std::optional primal_solution; std::optional dual_solution; std::optional basis; diff --git a/ortools/math_opt/cpp/solve_result.cc b/ortools/math_opt/cpp/solve_result.cc index 24b93fa269..a13a9e0ea7 100644 --- a/ortools/math_opt/cpp/solve_result.cc +++ b/ortools/math_opt/cpp/solve_result.cc @@ -26,35 +26,16 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/base/protoutil.h" +#include "ortools/base/status_macros.h" #include "ortools/math_opt/core/model_storage.h" #include "ortools/math_opt/cpp/linear_constraint.h" #include "ortools/math_opt/cpp/variable_and_expressions.h" #include "ortools/math_opt/solution.pb.h" #include "ortools/port/proto_utils.h" +#include "ortools/util/status_macros.h" namespace operations_research { namespace math_opt { -namespace { - -// Converts a map with BasisStatusProto values to a map with BasisStatus values -// CHECKing that no values are BASIS_STATUS_UNSPECIFIED (the validation code -// should have tested that). -// -// TODO(b/201344491): use FromProto() factory methods on solution members and -// remove the need for this conversion from `IndexedSolutions`. -template -absl::flat_hash_map BasisStatusMapFromProto( - const absl::flat_hash_map& proto_map) { - absl::flat_hash_map cpp_map; - for (const auto& [id, proto_value] : proto_map) { - const std::optional opt_status = EnumFromProto(proto_value); - CHECK(opt_status.has_value()); - cpp_map.emplace(id, *opt_status); - } - return cpp_map; -} - -} // namespace std::optional Enum::ToOptString( FeasibilityStatus value) { @@ -173,12 +154,10 @@ Termination Termination::NoSolutionFound(const Limit limit, return termination; } -TerminationProto Termination::ToProto() const { +TerminationProto Termination::Proto() const { TerminationProto proto; proto.set_reason(EnumToProto(reason)); - if (limit.has_value()) { - proto.set_limit(EnumToProto(*limit)); - } + proto.set_limit(EnumToProto(limit)); proto.set_detail(detail); return proto; } @@ -188,32 +167,16 @@ bool Termination::limit_reached() const { reason == TerminationReason::kNoSolutionFound; } -Termination Termination::FromProto(const TerminationProto& termination_proto) { - const bool limit_reached = - termination_proto.reason() == TERMINATION_REASON_FEASIBLE || - termination_proto.reason() == TERMINATION_REASON_NO_SOLUTION_FOUND; - const bool has_limit = termination_proto.limit() != LIMIT_UNSPECIFIED; - CHECK_EQ(limit_reached, has_limit) - << "Termination reason should be TERMINATION_REASON_FEASIBLE or " - "TERMINATION_REASON_NO_SOLUTION_FOUND if and only if limit is " - "specified, but found reason=" - << ProtoEnumToString(termination_proto.reason()) - << " and limit=" << ProtoEnumToString(termination_proto.limit()); - - if (has_limit) { - const std::optional opt_limit = - EnumFromProto(termination_proto.limit()); - CHECK(opt_limit.has_value()); - if (termination_proto.reason() == TERMINATION_REASON_FEASIBLE) { - return Feasible(*opt_limit, termination_proto.detail()); - } - return NoSolutionFound(*opt_limit, termination_proto.detail()); - } - - const std::optional opt_reason = +absl::StatusOr Termination::FromProto( + const TerminationProto& termination_proto) { + const std::optional reason = EnumFromProto(termination_proto.reason()); - CHECK(opt_reason.has_value()); - return Termination(*opt_reason, termination_proto.detail()); + if (!reason.has_value()) { + return absl::InvalidArgumentError("reason must be specified"); + } + Termination result(*reason, termination_proto.detail()); + result.limit = EnumFromProto(termination_proto.limit()); + return result; } std::ostream& operator<<(std::ostream& ostr, const Termination& termination) { @@ -235,7 +198,7 @@ std::string Termination::ToString() const { return stream.str(); } -ProblemStatusProto ProblemStatus::ToProto() const { +ProblemStatusProto ProblemStatus::Proto() const { ProblemStatusProto proto; proto.set_primal_status(EnumToProto(primal_status)); proto.set_dual_status(EnumToProto(dual_status)); @@ -243,21 +206,22 @@ ProblemStatusProto ProblemStatus::ToProto() const { return proto; } -ProblemStatus ProblemStatus::FromProto( +absl::StatusOr ProblemStatus::FromProto( const ProblemStatusProto& problem_status_proto) { - ProblemStatus result; - // TODO(b/209014770): consider adding a function to simplify this pattern. - const std::optional opt_primal_status = + const std::optional primal_status = EnumFromProto(problem_status_proto.primal_status()); - const std::optional opt_dual_status = + if (!primal_status.has_value()) { + return absl::InvalidArgumentError("primal_status must be specified"); + } + const std::optional dual_status = EnumFromProto(problem_status_proto.dual_status()); - CHECK(opt_primal_status.has_value()); - CHECK(opt_dual_status.has_value()); - result.primal_status = *opt_primal_status; - result.dual_status = *opt_dual_status; - result.primal_or_dual_infeasible = - problem_status_proto.primal_or_dual_infeasible(); - return result; + if (!dual_status.has_value()) { + return absl::InvalidArgumentError("dual_status must be specified"); + } + return ProblemStatus{.primal_status = *primal_status, + .dual_status = *dual_status, + .primal_or_dual_infeasible = + problem_status_proto.primal_or_dual_infeasible()}; } std::ostream& operator<<(std::ostream& ostr, @@ -276,13 +240,14 @@ std::string ProblemStatus::ToString() const { return stream.str(); } -SolveStatsProto SolveStats::ToProto() const { +absl::StatusOr SolveStats::Proto() const { SolveStatsProto proto; - CHECK_OK( - util_time::EncodeGoogleApiProto(solve_time, proto.mutable_solve_time())); + RETURN_IF_ERROR( + util_time::EncodeGoogleApiProto(solve_time, proto.mutable_solve_time())) + << "invalid solve_time (value must be finite)"; proto.set_best_primal_bound(best_primal_bound); proto.set_best_dual_bound(best_dual_bound); - *proto.mutable_problem_status() = problem_status.ToProto(); + *proto.mutable_problem_status() = problem_status.Proto(); proto.set_simplex_iterations(simplex_iterations); proto.set_barrier_iterations(barrier_iterations); proto.set_first_order_iterations(first_order_iterations); @@ -290,14 +255,19 @@ SolveStatsProto SolveStats::ToProto() const { return proto; } -SolveStats SolveStats::FromProto(const SolveStatsProto& solve_stats_proto) { +absl::StatusOr SolveStats::FromProto( + const SolveStatsProto& solve_stats_proto) { SolveStats result; - result.solve_time = - util_time::DecodeGoogleApiProto(solve_stats_proto.solve_time()).value(); + OR_ASSIGN_OR_RETURN3( + result.solve_time, + util_time::DecodeGoogleApiProto(solve_stats_proto.solve_time()), + _ << "invalid solve_time"); result.best_primal_bound = solve_stats_proto.best_primal_bound(); result.best_dual_bound = solve_stats_proto.best_dual_bound(); - result.problem_status = - ProblemStatus::FromProto(solve_stats_proto.problem_status()); + OR_ASSIGN_OR_RETURN3( + result.problem_status, + ProblemStatus::FromProto(solve_stats_proto.problem_status()), + _ << "invalid problem_status"); result.simplex_iterations = solve_stats_proto.simplex_iterations(); result.barrier_iterations = solve_stats_proto.barrier_iterations(); result.first_order_iterations = solve_stats_proto.first_order_iterations(); @@ -324,27 +294,80 @@ std::string SolveStats::ToString() const { return stream.str(); } -SolveResult SolveResult::FromProto(const ModelStorage* model, - const SolveResultProto& solve_result_proto) { - SolveResult result(Termination::FromProto(solve_result_proto.termination())); - result.solve_stats = SolveStats::FromProto(solve_result_proto.solve_stats()); +absl::Status CheckSolverSpecificOutputEmpty(const SolveResultProto& result) { + if (result.solver_specific_output_case() == + SolveResultProto::SOLVER_SPECIFIC_OUTPUT_NOT_SET) { + return absl::OkStatus(); + } + return util::InvalidArgumentErrorBuilder() + << "cannot set solver specific output twice, was already " + << static_cast(result.solver_specific_output_case()); +} - for (const SolutionProto& solution : solve_result_proto.solutions()) { - result.solutions.push_back(Solution::FromProto(model, solution)); +absl::StatusOr SolveResult::Proto() const { + SolveResultProto result; + *result.mutable_termination() = termination.Proto(); + OR_ASSIGN_OR_RETURN3(*result.mutable_solve_stats(), solve_stats.Proto(), + _ << "invalid solve_stats"); + for (const Solution& solution : solutions) { + *result.add_solutions() = solution.Proto(); } - for (const PrimalRayProto& primal_ray : solve_result_proto.primal_rays()) { - result.primal_rays.push_back(PrimalRay::FromProto(model, primal_ray)); + for (const PrimalRay& primal_ray : primal_rays) { + *result.add_primal_rays() = primal_ray.Proto(); } - for (const DualRayProto& dual_ray : solve_result_proto.dual_rays()) { - result.dual_rays.push_back(DualRay::FromProto(model, dual_ray)); + for (const DualRay& dual_ray : dual_rays) { + *result.add_dual_rays() = dual_ray.Proto(); } - if (solve_result_proto.has_gscip_output()) { - result.gscip_solver_specific_output = - std::move(solve_result_proto.gscip_output()); + // See yaqs/5107601535926272 on checking if a proto is empty. + if (gscip_solver_specific_output.ByteSizeLong() > 0) { + *result.mutable_gscip_output() = gscip_solver_specific_output; } return result; } +absl::StatusOr SolveResult::FromProto( + const ModelStorage* model, const SolveResultProto& solve_result_proto) { + OR_ASSIGN_OR_RETURN3(auto termination, + Termination::FromProto(solve_result_proto.termination()), + _ << "invalid termination"); + SolveResult result(std::move(termination)); + OR_ASSIGN_OR_RETURN3(result.solve_stats, + SolveStats::FromProto(solve_result_proto.solve_stats()), + _ << "invalid solve_stats"); + + for (int i = 0; i < solve_result_proto.solutions_size(); ++i) { + OR_ASSIGN_OR_RETURN3( + auto solution, + Solution::FromProto(model, solve_result_proto.solutions(i)), + _ << "invalid solution at index " << i); + result.solutions.push_back(std::move(solution)); + } + for (int i = 0; i < solve_result_proto.primal_rays_size(); ++i) { + OR_ASSIGN_OR_RETURN3( + auto primal_ray, + PrimalRay::FromProto(model, solve_result_proto.primal_rays(i)), + _ << "invalid primal ray at index " << i); + result.primal_rays.push_back(std::move(primal_ray)); + } + for (int i = 0; i < solve_result_proto.dual_rays_size(); ++i) { + OR_ASSIGN_OR_RETURN3( + auto dual_ray, + DualRay::FromProto(model, solve_result_proto.dual_rays(i)), + _ << "invalid dual ray at index " << i); + result.dual_rays.push_back(std::move(dual_ray)); + } + switch (solve_result_proto.solver_specific_output_case()) { + case SolveResultProto::kGscipOutput: + result.gscip_solver_specific_output = solve_result_proto.gscip_output(); + return result; + case SolveResultProto::SOLVER_SPECIFIC_OUTPUT_NOT_SET: + return result; + } + return util::InvalidArgumentErrorBuilder() + << "unexpected value of solver_specific_output_case " + << solve_result_proto.solver_specific_output_case(); +} + bool SolveResult::has_primal_feasible_solution() const { return !solutions.empty() && solutions[0].primal_solution.has_value() && (solutions[0].primal_solution->feasibility_status == diff --git a/ortools/math_opt/cpp/solve_result.h b/ortools/math_opt/cpp/solve_result.h index 6a09e5b5d5..751133fd73 100644 --- a/ortools/math_opt/cpp/solve_result.h +++ b/ortools/math_opt/cpp/solve_result.h @@ -19,6 +19,7 @@ #include #include +#include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "ortools/base/logging.h" @@ -79,10 +80,11 @@ struct ProblemStatus { // infeasibility, unboundedness, or both). bool primal_or_dual_infeasible = false; - static ProblemStatus FromProto( + // Returns an error if the primal_status or dual_status is unspecified. + static absl::StatusOr FromProto( const ProblemStatusProto& problem_status_proto); - ProblemStatusProto ToProto() const; + ProblemStatusProto Proto() const; std::string ToString() const; }; @@ -108,7 +110,6 @@ struct SolveStats { // may be non-trivial even when no primal feasible solutions are returned. // * best_dual_bound is always better (smaller for minimization and larger // for maximization) than best_primal_bound. - double best_primal_bound = 0.0; // Solver claims the optimal value is equal or worse (larger for @@ -137,10 +138,12 @@ struct SolveStats { int node_count = 0; - // Will CHECK fail on invalid input, if problem_status is invalid. - static SolveStats FromProto(const SolveStatsProto& solve_stats_proto); + // Returns an error if converting the problem_status or solve_time fails. + static absl::StatusOr FromProto( + const SolveStatsProto& solve_stats_proto); - SolveStatsProto ToProto() const; + // Will return an error if solve_time is not finite. + absl::StatusOr Proto() const; std::string ToString() const; }; @@ -258,9 +261,15 @@ struct Termination { // functions Feasible and NoSolutionFound. explicit Termination(TerminationReason reason, std::string detail = {}); + // Additional information in `limit` when value is kFeasible or + // kNoSolutionFound, see `limit` for details. TerminationReason reason; - // Is set iff reason is kFeasible or kNoSolutionFound. + // A Termination within a SolveResult returned by math_opt::Solve() satisfies + // some additional invariants: + // * limit is set iff reason is kFeasible or kNoSolutionFound. + // * if the limit is kCutoff, the termination reason will be + // kNoSolutionFound. std::optional limit; // Additional typically solver specific information about termination. @@ -272,20 +281,16 @@ struct Termination { // kNoSolutionFound, and limit is not empty). bool limit_reached() const; - // Will CHECK fail on invalid input, if reason is unspecified, if limit is - // set when reason is not TERMINATION_REASON_FEASIBLE or - // TERMINATION_REASON_NO_SOLUTION_FOUND, or if limit is unspecified when - // reason is TERMINATION_REASON_FEASIBLE or - // TERMINATION_REASON_NO_SOLUTION_FOUND (see solution_validator.h). - static Termination FromProto(const TerminationProto& termination_proto); - // Sets the reason to kFeasible static Termination Feasible(Limit limit, std::string detail = {}); // Sets the reason to kNoSolutionFound static Termination NoSolutionFound(Limit limit, std::string detail = {}); - TerminationProto ToProto() const; + // Will return an error if termination_proto.reason is UNSPECIFIED. + static absl::StatusOr FromProto( + const TerminationProto& termination_proto); + TerminationProto Proto() const; std::string ToString() const; }; @@ -332,15 +337,43 @@ struct SolveResult { // Solver specific output from Gscip. Only populated if Gscip is used. GScipOutput gscip_solver_specific_output; - static SolveResult FromProto(const ModelStorage* model, - const SolveResultProto& solve_result_proto); + // Returns the SolveResult equivalent of solve_result_proto. + // + // Returns an error if: + // * Any solution or ray cannot be read from proto (e.g. on a subfield, + // ids.size != values.size). + // * termination or solve_result cannot be read from proto. + // See the FromProto() functions for these types for details. + // + // Note: this is (intentionally) a much weaker test than ValidateResult(). The + // guarantees are just strong enough to ensure that a SolveResult and + // SolveResultProto can round trip cleanly, e.g. we do not check that a + // termination reason optimal implies that there is at least one primal + // feasible solution. + // + // While ValidateResult() is called automatically when you are solving + // locally, users who are reading a solution from disk, solving remotely, or + // getting their SolveResultProto (or SolveResult) by any other means are + // encouraged to either call ValidateResult() themselves, do their own + // validation, or not rely on the strong guarantees of ValidateResult() + // and just treat SolveResult as a simple struct. + static absl::StatusOr FromProto( + const ModelStorage* model, const SolveResultProto& solve_result_proto); + + // Returns the proto equivalent of this. + // + // Note that the proto uses a oneof for solver specific output. This method + // will fail if multiple solver specific outputs are set. TODO(b/231134639): + // investigate removing the oneof from the proto. + absl::StatusOr Proto() const; absl::Duration solve_time() const { return solve_stats.solve_time; } // Indicates if at least one primal feasible solution is available. // - // When termination.reason is TerminationReason::kOptimal, this is guaranteed - // to be true and need not be checked. + // When termination.reason is TerminationReason::kOptimal or + // TerminationReason::kFeasible, this is guaranteed to be true and need not be + // checked. bool has_primal_feasible_solution() const; // The objective value of the best primal feasible solution. Will CHECK fail @@ -367,20 +400,27 @@ struct SolveResult { // are no primal rays. const VariableMap& ray_variable_values() const; - // Indicates if the best primal solution has an associated dual feasible - // solution. + // Indicates if the best solution has an associated dual feasible solution. // // This is NOT guaranteed to be true when termination.reason is - // TerminationReason::kOptimal. It also may be true even when the best primal - // solution is not feasible. + // TerminationReason::kOptimal. It also may be true even when the best + // solution does not have an associated primal feasible solution. bool has_dual_feasible_solution() const; - // The dual values from the best dual solution. Will CHECK fail if there - // are no dual solutions. + // The dual values associated to the best solution. + // + // If there is at least one primal feasible solution, this corresponds to the + // dual values associated to the best primal feasible solution. Will CHECK + // fail if the best solution does not have an associated dual feasible + // solution. const LinearConstraintMap& dual_values() const; - // The reduced from the best dual solution. Will CHECK fail if there - // are no dual solutions. + // The reduced costs associated to the best solution. + // + // If there is at least one primal feasible solution, this corresponds to the + // reduced costs associated to the best primal feasible solution. Will CHECK + // fail if the best solution does not have an associated dual feasible + // solution. const VariableMap& reduced_costs() const; // Indicates if at least one dual ray is available. @@ -397,13 +437,13 @@ struct SolveResult { // are no dual rays. const VariableMap& ray_reduced_costs() const; - // Indicates if at least one basis is available. + // Indicates if the best solution has an associated basis. bool has_basis() const; - // The constraint basis status for the first primal/dual pair. + // The constraint basis status for the best solution. const LinearConstraintMap& constraint_status() const; - // The variable basis status for the first primal/dual pair. + // The variable basis status for the best solution. const VariableMap& variable_status() const; }; diff --git a/ortools/math_opt/cpp/sparse_containers.cc b/ortools/math_opt/cpp/sparse_containers.cc new file mode 100644 index 0000000000..ddcd125023 --- /dev/null +++ b/ortools/math_opt/cpp/sparse_containers.cc @@ -0,0 +1,154 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/math_opt/cpp/sparse_containers.h" + +namespace operations_research::math_opt { +namespace { + +// SparseVectorProtoType should be SparseDoubleVector or SparseBasisStatusVector +template +absl::Status CheckSparseVectorProto(const SparseVectorProtoType& vec) { + RETURN_IF_ERROR(CheckIdsAndValuesSize(MakeView(vec))); + RETURN_IF_ERROR(CheckIdsRangeAndStrictlyIncreasing(vec.ids())); + return absl::OkStatus(); +} + +template +absl::StatusOr> BasisVectorFromProto( + const ModelStorage* const model, + const SparseBasisStatusVector& basis_proto) { + using IdType = typename Key::IdType; + absl::flat_hash_map raw_map; + raw_map.reserve(basis_proto.ids_size()); + for (const auto& [id, basis_status_proto_int] : MakeView(basis_proto)) { + const auto basis_status_proto = + static_cast(basis_status_proto_int); + const std::optional basis_status = + EnumFromProto(basis_status_proto); + if (!basis_status.has_value()) { + return util::InvalidArgumentErrorBuilder() + << "basis status not specified for id " << id; + } + raw_map[IdType(id)] = *basis_status; + } + return IdMap(model, std::move(raw_map)); +} + +template +SparseDoubleVectorProto IdMapToProto(const IdMap& id_map) { + using IdType = typename Key::IdType; + SparseDoubleVectorProto result; + std::vector> sorted_entries( + id_map.raw_map().begin(), id_map.raw_map().end()); + std::sort(sorted_entries.begin(), sorted_entries.end()); + for (const auto& [id, val] : sorted_entries) { + result.add_ids(id.value()); + result.add_values(val); + } + return result; +} + +template +SparseBasisStatusVector BasisIdMapToProto( + const IdMap& basis_map) { + using IdType = typename Key::IdType; + SparseBasisStatusVector result; + std::vector> sorted_entries( + basis_map.raw_map().begin(), basis_map.raw_map().end()); + std::sort(sorted_entries.begin(), sorted_entries.end()); + for (const auto& [id, val] : sorted_entries) { + result.add_ids(id.value()); + result.add_values(EnumToProto(val)); + } + return result; +} + +absl::Status VariableIdsExist(const ModelStorage* const model, + const absl::Span ids) { + for (const int64_t id : ids) { + if (!model->has_variable(VariableId(id))) { + return util::InvalidArgumentErrorBuilder() + << "no variable with id " << id << " exists"; + } + } + return absl::OkStatus(); +} + +absl::Status LinearConstraintIdsExist(const ModelStorage* const model, + const absl::Span ids) { + for (const int64_t id : ids) { + if (!model->has_linear_constraint(LinearConstraintId(id))) { + return util::InvalidArgumentErrorBuilder() + << "no linear constraint with id " << id << " exists"; + } + } + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> VariableValuesFromProto( + const ModelStorage* const model, + const SparseDoubleVectorProto& vars_proto) { + RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto)); + RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids())); + return VariableMap(model, MakeView(vars_proto).as_map()); +} + +SparseDoubleVectorProto VariableValuesToProto( + const VariableMap& variable_values) { + return IdMapToProto(variable_values); +} + +absl::StatusOr> LinearConstraintValuesFromProto( + const ModelStorage* const model, + const SparseDoubleVectorProto& lin_cons_proto) { + RETURN_IF_ERROR(CheckSparseVectorProto(lin_cons_proto)); + RETURN_IF_ERROR(LinearConstraintIdsExist(model, lin_cons_proto.ids())); + return LinearConstraintMap( + model, MakeView(lin_cons_proto).as_map()); +} + +SparseDoubleVectorProto LinearConstraintValuesToProto( + const LinearConstraintMap& linear_constraint_values) { + return IdMapToProto(linear_constraint_values); +} + +absl::StatusOr> VariableBasisFromProto( + const ModelStorage* const model, + const SparseBasisStatusVector& basis_proto) { + RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto)); + RETURN_IF_ERROR(VariableIdsExist(model, basis_proto.ids())); + return BasisVectorFromProto(model, basis_proto); +} + +SparseBasisStatusVector VariableBasisToProto( + const VariableMap& basis_values) { + return BasisIdMapToProto(basis_values); +} + +absl::StatusOr> LinearConstraintBasisFromProto( + const ModelStorage* const model, + const SparseBasisStatusVector& basis_proto) { + RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto)); + RETURN_IF_ERROR(LinearConstraintIdsExist(model, basis_proto.ids())); + return BasisVectorFromProto(model, basis_proto); +} + +SparseBasisStatusVector LinearConstraintBasisToProto( + const LinearConstraintMap& basis_values) { + return BasisIdMapToProto(basis_values); +} + +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/cpp/sparse_containers.h b/ortools/math_opt/cpp/sparse_containers.h new file mode 100644 index 0000000000..d9c5757111 --- /dev/null +++ b/ortools/math_opt/cpp/sparse_containers.h @@ -0,0 +1,104 @@ +// Copyright 2010-2021 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_MATH_OPT_CPP_SPARSE_CONTAINERS_H_ +#define OR_TOOLS_MATH_OPT_CPP_SPARSE_CONTAINERS_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/base/status_builder.h" +#include "ortools/base/status_macros.h" +#include "ortools/math_opt/core/model_storage.h" +#include "ortools/math_opt/core/sparse_vector_view.h" +#include "ortools/math_opt/cpp/basis_status.h" +#include "ortools/math_opt/cpp/linear_constraint.h" +#include "ortools/math_opt/cpp/variable_and_expressions.h" +#include "ortools/math_opt/solution.pb.h" +#include "ortools/math_opt/sparse_containers.pb.h" +#include "ortools/math_opt/validators/ids_validator.h" +#include "ortools/math_opt/validators/sparse_vector_validator.h" +#include "ortools/util/status_macros.h" + +namespace operations_research::math_opt { + +// Returns the VariableMap equivalent to `vars_proto`. +// +// Requires that (or returns a status error): +// * vars_proto.ids and vars_proto.values have equal size. +// * vars_proto.ids is sorted. +// * vars_proto.ids has elements in [0, max(int64_t)). +// * vars_proto.ids has elements that are variables in `model`. +// +// Note that the values of vars_proto.values are not checked (it may have NaNs). +absl::StatusOr> VariableValuesFromProto( + const ModelStorage* const model, const SparseDoubleVectorProto& vars_proto); + +// Returns the proto equivalent of variable_values. +SparseDoubleVectorProto VariableValuesToProto( + const VariableMap& variable_values); + +// Returns the LinearConstraintMap equivalent to `lin_cons_proto`. +// +// Requires that (or returns a status error): +// * lin_cons_proto.ids and lin_cons_proto.values have equal size. +// * lin_cons_proto.ids is sorted. +// * lin_cons_proto.ids has elements in [0, max(int64_t)). +// * lin_cons_proto.ids has elements that are linear constraints in `model`. +// +// Note that the values of lin_cons_proto.values are not checked (it may have +// NaNs). +absl::StatusOr> LinearConstraintValuesFromProto( + const ModelStorage* const model, + const SparseDoubleVectorProto& lin_cons_proto); + +// Returns the proto equivalent of linear_constraint_values. +SparseDoubleVectorProto LinearConstraintValuesToProto( + const LinearConstraintMap& linear_constraint_values); + +// Returns the VariableMap equivalent to `basis_proto`. +// +// Requires that (or returns a status error): +// * basis_proto.ids and basis_proto.values have equal size. +// * basis_proto.ids is sorted. +// * basis_proto.ids has elements in [0, max(int64_t)). +// * basis_proto.ids has elements that are variables in `model`. +// * basis_proto.values does not contain UNSPECIFIED and has valid enum values. +absl::StatusOr> VariableBasisFromProto( + const ModelStorage* const model, + const SparseBasisStatusVector& basis_proto); + +// Returns the proto equivalent of basis_values. +SparseBasisStatusVector VariableBasisToProto( + const VariableMap& basis_values); + +// Returns the LinearConstraintMap equivalent to `basis_proto`. +// +// Requires that (or returns a status error): +// * basis_proto.ids and basis_proto.values have equal size. +// * basis_proto.ids is sorted. +// * basis_proto.ids has elements in [0, max(int64_t)). +// * basis_proto.ids has elements that are linear constraints in `model`. +// * basis_proto.values does not contain UNSPECIFIED and has valid enum values. +absl::StatusOr> LinearConstraintBasisFromProto( + const ModelStorage* const model, + const SparseBasisStatusVector& basis_proto); + +// Returns the proto equivalent of basis_values. +SparseBasisStatusVector LinearConstraintBasisToProto( + const LinearConstraintMap& basis_values); + +} // namespace operations_research::math_opt + +#endif // OR_TOOLS_MATH_OPT_CPP_SPARSE_CONTAINERS_H_ diff --git a/ortools/math_opt/cpp/variable_and_expressions.cc b/ortools/math_opt/cpp/variable_and_expressions.cc index f1ff75c159..925f92d74a 100644 --- a/ortools/math_opt/cpp/variable_and_expressions.cc +++ b/ortools/math_opt/cpp/variable_and_expressions.cc @@ -267,6 +267,24 @@ std::ostream& operator<<(std::ostream& ostr, const QuadraticExpression& expr) { return ostr; } +std::ostream& operator<<(std::ostream& ostr, + const BoundedQuadraticExpression& bounded_expression) { + // TODO(b/170991498): use bijective conversion from double to base-10 string + // to make sure we can reproduce bugs. + const double lb = bounded_expression.lower_bound; + const double ub = bounded_expression.upper_bound; + if (lb == ub) { + ostr << bounded_expression.expression << " = " << lb; + } else if (lb == -kInf) { + ostr << bounded_expression.expression << " ≤ " << ub; + } else if (ub == kInf) { + ostr << bounded_expression.expression << " ≥ " << lb; + } else { + ostr << lb << " ≤ " << bounded_expression.expression << " ≤ " << ub; + } + return ostr; +} + #ifdef MATH_OPT_USE_EXPRESSION_COUNTERS QuadraticExpression::QuadraticExpression() { ++num_calls_default_constructor_; } diff --git a/ortools/math_opt/cpp/variable_and_expressions.h b/ortools/math_opt/cpp/variable_and_expressions.h index ace3dfb8a3..1b388bad7a 100644 --- a/ortools/math_opt/cpp/variable_and_expressions.h +++ b/ortools/math_opt/cpp/variable_and_expressions.h @@ -477,8 +477,8 @@ inline bool operator!=(const Variable& lhs, const Variable& rhs); // A LinearExpression with a lower bound. struct LowerBoundedLinearExpression { - // Users are not expected to use this constructor. Instead they should build - // this object using the overloads of >= and <= operators. For example `x + y + // Users are not expected to use this constructor. Instead, they should build + // this object using overloads of the >= and <= operators. For example, `x + y // >= 3`. inline LowerBoundedLinearExpression(LinearExpression expression, double lower_bound); @@ -489,7 +489,7 @@ struct LowerBoundedLinearExpression { // A LinearExpression with an upper bound. struct UpperBoundedLinearExpression { // Users are not expected to use this constructor. Instead they should build - // this object using the overloads of >= and <= operators. For example `x + y + // this object using overloads of the >= and <= operators. For example, `x + y // <= 3`. inline UpperBoundedLinearExpression(LinearExpression expression, double upper_bound); @@ -500,8 +500,8 @@ struct UpperBoundedLinearExpression { // A LinearExpression with upper and lower bounds. struct BoundedLinearExpression { // Users are not expected to use this constructor. Instead they should build - // this object using the overloads of >= and <= operators. For example `3 <= x - // + y <= 3`. + // this object using overloads of the >=, <=, and == operators. For example, + // `3 <= x + y <= 3`. inline BoundedLinearExpression(LinearExpression expression, double lower_bound, double upper_bound); // Users are not expected to use this constructor. This implicit conversion @@ -1019,6 +1019,207 @@ inline QuadraticExpression operator*(QuadraticExpression lhs, double rhs); inline QuadraticExpression operator/(QuadraticExpression lhs, double rhs); +// A QuadraticExpression with a lower bound. +struct LowerBoundedQuadraticExpression { + // Users are not expected to use this constructor. Instead, they should build + // this object using overloads of the >= and <= operators. For example, `x * y + // >= 3`. + inline LowerBoundedQuadraticExpression(QuadraticExpression expression, + double lower_bound); + // Users are not expected to explicitly use the following constructor. + inline LowerBoundedQuadraticExpression( // NOLINT + LowerBoundedLinearExpression lb_expression); + + QuadraticExpression expression; + double lower_bound; +}; + +// A QuadraticExpression with an upper bound. +struct UpperBoundedQuadraticExpression { + // Users are not expected to use this constructor. Instead, they should build + // this object using overloads of the >= and <= operators. For example, `x * y + // <= 3`. + inline UpperBoundedQuadraticExpression(QuadraticExpression expression, + double upper_bound); + // Users are not expected to explicitly use the following constructor. + inline UpperBoundedQuadraticExpression( // NOLINT + UpperBoundedLinearExpression ub_expression); + + QuadraticExpression expression; + double upper_bound; +}; + +// A QuadraticExpression with upper and lower bounds. +struct BoundedQuadraticExpression { + // Users are not expected to use this constructor. Instead, they should build + // this object using overloads of the >=, <=, and == operators. For example, + // `3 <= x * y <= 3`. + inline BoundedQuadraticExpression(QuadraticExpression expression, + double lower_bound, double upper_bound); + + // Users are not expected to explicitly use the following constructors. + inline BoundedQuadraticExpression( // NOLINT + internal::VariablesEquality var_equality); + inline BoundedQuadraticExpression( // NOLINT + LowerBoundedLinearExpression lb_expression); + inline BoundedQuadraticExpression( // NOLINT + UpperBoundedLinearExpression ub_expression); + inline BoundedQuadraticExpression( // NOLINT + BoundedLinearExpression bounded_expression); + inline BoundedQuadraticExpression( // NOLINT + LowerBoundedQuadraticExpression lb_expression); + inline BoundedQuadraticExpression( // NOLINT + UpperBoundedQuadraticExpression ub_expression); + + // Returns the actual lower_bound after taking into account the quadratic + // expression offset. + inline double lower_bound_minus_offset() const; + // Returns the actual upper_bound after taking into account the quadratic + // expression offset. + inline double upper_bound_minus_offset() const; + + QuadraticExpression expression; + double lower_bound; + double upper_bound; +}; + +std::ostream& operator<<(std::ostream& ostr, + const BoundedQuadraticExpression& bounded_expression); + +// We intentionally pass the QuadraticExpression argument by value so that we +// don't make unnecessary copies of temporary objects by using the move +// constructor and the returned values optimization (RVO). +inline LowerBoundedQuadraticExpression operator>=(QuadraticExpression lhs, + double rhs); +inline LowerBoundedQuadraticExpression operator>=(QuadraticTerm lhs, + double rhs); +inline LowerBoundedQuadraticExpression operator<=(double lhs, + QuadraticExpression rhs); +inline LowerBoundedQuadraticExpression operator<=(double lhs, + QuadraticTerm rhs); + +inline UpperBoundedQuadraticExpression operator>=(double lhs, + QuadraticExpression rhs); +inline UpperBoundedQuadraticExpression operator>=(double lhs, + QuadraticTerm rhs); +inline UpperBoundedQuadraticExpression operator<=(QuadraticExpression lhs, + double rhs); +inline UpperBoundedQuadraticExpression operator<=(QuadraticTerm lhs, + double rhs); + +// We intentionally pass the UpperBoundedQuadraticExpression and +// LowerBoundedQuadraticExpression arguments by value so that we don't +// make unnecessary copies of temporary objects by using the move constructor +// and the returned values optimization (RVO). +inline BoundedQuadraticExpression operator>=( + UpperBoundedQuadraticExpression lhs, double rhs); +inline BoundedQuadraticExpression operator>=( + double lhs, LowerBoundedQuadraticExpression rhs); +inline BoundedQuadraticExpression operator<=( + LowerBoundedQuadraticExpression lhs, double rhs); +inline BoundedQuadraticExpression operator<=( + double lhs, UpperBoundedQuadraticExpression rhs); +// We intentionally pass one QuadraticExpression argument by value so that we +// don't make unnecessary copies of temporary objects by using the move +// constructor and the returned values optimization (RVO). + +// Comparisons with lhs = QuadraticExpression +inline BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const QuadraticExpression& rhs); +inline BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const LinearExpression& rhs); +inline BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + LinearTerm rhs); +inline BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + Variable rhs); +inline BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const QuadraticExpression& rhs); +inline BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const LinearExpression& rhs); +inline BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + LinearTerm rhs); +inline BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + Variable rhs); +inline BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const QuadraticExpression& rhs); +inline BoundedQuadraticExpression operator==(QuadraticExpression lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const LinearExpression& rhs); +inline BoundedQuadraticExpression operator==(QuadraticExpression lhs, + LinearTerm rhs); +inline BoundedQuadraticExpression operator==(QuadraticExpression lhs, + Variable rhs); +inline BoundedQuadraticExpression operator==(QuadraticExpression lhs, + double rhs); +// Comparisons with lhs = QuadraticTerm +inline BoundedQuadraticExpression operator>=(QuadraticTerm lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator>=(QuadraticTerm lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator>=(QuadraticTerm lhs, + LinearExpression rhs); +inline BoundedQuadraticExpression operator>=(QuadraticTerm lhs, LinearTerm rhs); +inline BoundedQuadraticExpression operator>=(QuadraticTerm lhs, Variable rhs); +inline BoundedQuadraticExpression operator<=(QuadraticTerm lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator<=(QuadraticTerm lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator<=(QuadraticTerm lhs, + LinearExpression rhs); +inline BoundedQuadraticExpression operator<=(QuadraticTerm lhs, LinearTerm rhs); +inline BoundedQuadraticExpression operator<=(QuadraticTerm lhs, Variable rhs); +inline BoundedQuadraticExpression operator==(QuadraticTerm lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator==(QuadraticTerm lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator==(QuadraticTerm lhs, + LinearExpression rhs); +inline BoundedQuadraticExpression operator==(QuadraticTerm lhs, LinearTerm rhs); +inline BoundedQuadraticExpression operator==(QuadraticTerm lhs, Variable rhs); +inline BoundedQuadraticExpression operator==(QuadraticTerm lhs, double rhs); +// Comparisons with lhs = LinearExpression +inline BoundedQuadraticExpression operator>=(const LinearExpression& lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator>=(LinearExpression lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator<=(const LinearExpression& lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator<=(LinearExpression lhs, + QuadraticTerm rhs); +inline BoundedQuadraticExpression operator==(const LinearExpression& lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator==(LinearExpression lhs, + QuadraticTerm rhs); +// Comparisons with lhs = LinearTerm +inline BoundedQuadraticExpression operator>=(LinearTerm lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator>=(LinearTerm lhs, QuadraticTerm rhs); +inline BoundedQuadraticExpression operator<=(LinearTerm lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator<=(LinearTerm lhs, QuadraticTerm rhs); +inline BoundedQuadraticExpression operator==(LinearTerm lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator==(LinearTerm lhs, QuadraticTerm rhs); +// Comparisons with lhs = Variable +inline BoundedQuadraticExpression operator>=(Variable lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator>=(Variable lhs, QuadraticTerm rhs); +inline BoundedQuadraticExpression operator<=(Variable lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator<=(Variable lhs, QuadraticTerm rhs); +inline BoundedQuadraticExpression operator==(Variable lhs, + QuadraticExpression rhs); +inline BoundedQuadraticExpression operator==(Variable lhs, QuadraticTerm rhs); +// Comparisons with lhs = Double +inline BoundedQuadraticExpression operator==(double lhs, QuadraticTerm rhs); +inline BoundedQuadraticExpression operator==(double lhs, + QuadraticExpression rhs); + //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// // Inline function implementations ///////////////////////////////////////////// @@ -2434,6 +2635,405 @@ QuadraticExpression QuadraticExpression::InnerProduct( return result; } +///////////////////////////////////////////////////////////////////////////////// +// LowerBoundedQuadraticExpression +// UpperBoundedQuadraticExpression +// BoundedQuadraticExpression +//////////////////////////////////////////////////////////////////////////////// + +LowerBoundedQuadraticExpression::LowerBoundedQuadraticExpression( + QuadraticExpression expression, const double lower_bound) + : expression(std::move(expression)), lower_bound(lower_bound) {} +LowerBoundedQuadraticExpression::LowerBoundedQuadraticExpression( + LowerBoundedLinearExpression lb_expression) + : expression(std::move(lb_expression.expression)), + lower_bound(lb_expression.lower_bound) {} + +UpperBoundedQuadraticExpression::UpperBoundedQuadraticExpression( + QuadraticExpression expression, const double upper_bound) + : expression(std::move(expression)), upper_bound(upper_bound) {} +UpperBoundedQuadraticExpression::UpperBoundedQuadraticExpression( + UpperBoundedLinearExpression ub_expression) + : expression(std::move(ub_expression.expression)), + upper_bound(ub_expression.upper_bound) {} + +BoundedQuadraticExpression::BoundedQuadraticExpression( + QuadraticExpression expression, const double lower_bound, + const double upper_bound) + : expression(std::move(expression)), + lower_bound(lower_bound), + upper_bound(upper_bound) {} +BoundedQuadraticExpression::BoundedQuadraticExpression( + internal::VariablesEquality var_equality) + : lower_bound(0), upper_bound(0) { + expression += var_equality.lhs; + expression -= var_equality.rhs; +} +BoundedQuadraticExpression::BoundedQuadraticExpression( + LowerBoundedLinearExpression lb_expression) + : expression(std::move(lb_expression.expression)), + lower_bound(lb_expression.lower_bound), + upper_bound(std::numeric_limits::infinity()) {} +BoundedQuadraticExpression::BoundedQuadraticExpression( + UpperBoundedLinearExpression ub_expression) + : expression(std::move(ub_expression.expression)), + lower_bound(-std::numeric_limits::infinity()), + upper_bound(ub_expression.upper_bound) {} +BoundedQuadraticExpression::BoundedQuadraticExpression( + BoundedLinearExpression bounded_expression) + : expression(std::move(bounded_expression.expression)), + lower_bound(bounded_expression.lower_bound), + upper_bound(bounded_expression.upper_bound) {} +BoundedQuadraticExpression::BoundedQuadraticExpression( + LowerBoundedQuadraticExpression lb_expression) + : expression(std::move(lb_expression.expression)), + lower_bound(lb_expression.lower_bound), + upper_bound(std::numeric_limits::infinity()) {} +BoundedQuadraticExpression::BoundedQuadraticExpression( + UpperBoundedQuadraticExpression ub_expression) + : expression(std::move(ub_expression.expression)), + lower_bound(-std::numeric_limits::infinity()), + upper_bound(ub_expression.upper_bound) {} + +double BoundedQuadraticExpression::lower_bound_minus_offset() const { + return lower_bound - expression.offset(); +} + +double BoundedQuadraticExpression::upper_bound_minus_offset() const { + return upper_bound - expression.offset(); +} + +LowerBoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const double rhs) { + return LowerBoundedQuadraticExpression(std::move(lhs), rhs); +} +LowerBoundedQuadraticExpression operator>=(const QuadraticTerm lhs, + const double rhs) { + return LowerBoundedQuadraticExpression(lhs, rhs); +} +LowerBoundedQuadraticExpression operator<=(const double lhs, + QuadraticExpression rhs) { + return LowerBoundedQuadraticExpression(std::move(rhs), lhs); +} +LowerBoundedQuadraticExpression operator<=(const double lhs, + const QuadraticTerm rhs) { + return LowerBoundedQuadraticExpression(rhs, lhs); +} + +UpperBoundedQuadraticExpression operator>=(const double lhs, + QuadraticExpression rhs) { + return UpperBoundedQuadraticExpression(std::move(rhs), lhs); +} +UpperBoundedQuadraticExpression operator>=(const double lhs, + const QuadraticTerm rhs) { + return UpperBoundedQuadraticExpression(rhs, lhs); +} +UpperBoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const double rhs) { + return UpperBoundedQuadraticExpression(std::move(lhs), rhs); +} +UpperBoundedQuadraticExpression operator<=(const QuadraticTerm lhs, + const double rhs) { + return UpperBoundedQuadraticExpression(lhs, rhs); +} + +BoundedQuadraticExpression operator>=(UpperBoundedQuadraticExpression lhs, + const double rhs) { + return BoundedQuadraticExpression(std::move(lhs.expression), rhs, + lhs.upper_bound); +} +BoundedQuadraticExpression operator>=(const double lhs, + LowerBoundedQuadraticExpression rhs) { + return BoundedQuadraticExpression(std::move(rhs.expression), rhs.lower_bound, + lhs); +} +BoundedQuadraticExpression operator<=(LowerBoundedQuadraticExpression lhs, + const double rhs) { + return BoundedQuadraticExpression(std::move(lhs.expression), lhs.lower_bound, + rhs); +} +BoundedQuadraticExpression operator<=(const double lhs, + UpperBoundedQuadraticExpression rhs) { + return BoundedQuadraticExpression(std::move(rhs.expression), lhs, + rhs.upper_bound); +} + +BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const QuadraticExpression& rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const QuadraticTerm rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const LinearExpression& rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const LinearTerm rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator>=(QuadraticExpression lhs, + const Variable rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const QuadraticExpression& rhs) { + lhs -= rhs; + return BoundedQuadraticExpression( + std::move(lhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const QuadraticTerm rhs) { + lhs -= rhs; + return BoundedQuadraticExpression( + std::move(lhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const LinearExpression& rhs) { + lhs -= rhs; + return BoundedQuadraticExpression( + std::move(lhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const LinearTerm rhs) { + lhs -= rhs; + return BoundedQuadraticExpression( + std::move(lhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(QuadraticExpression lhs, + const Variable rhs) { + lhs -= rhs; + return BoundedQuadraticExpression( + std::move(lhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const QuadraticExpression& rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, 0); +} +BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const QuadraticTerm rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, 0); +} +BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const LinearExpression& rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, 0); +} +BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const LinearTerm rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, 0); +} +BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const Variable rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, 0); +} +BoundedQuadraticExpression operator==(QuadraticExpression lhs, + const double rhs) { + lhs -= rhs; + return BoundedQuadraticExpression(std::move(lhs), 0, 0); +} + +BoundedQuadraticExpression operator>=(const QuadraticTerm lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression( + std::move(rhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(const QuadraticTerm lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression( + rhs - lhs, -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(const QuadraticTerm lhs, + LinearExpression rhs) { + return BoundedQuadraticExpression( + std::move(rhs) - lhs, -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(const QuadraticTerm lhs, + const LinearTerm rhs) { + return BoundedQuadraticExpression( + rhs - lhs, -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(const QuadraticTerm lhs, + const Variable rhs) { + return BoundedQuadraticExpression( + rhs - lhs, -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(const QuadraticTerm lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(const QuadraticTerm lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(const QuadraticTerm lhs, + LinearExpression rhs) { + return BoundedQuadraticExpression(std::move(rhs) - lhs, 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(const QuadraticTerm lhs, + const LinearTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(const QuadraticTerm lhs, + const Variable rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator==(const QuadraticTerm lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, 0); +} +BoundedQuadraticExpression operator==(const QuadraticTerm lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} +BoundedQuadraticExpression operator==(const QuadraticTerm lhs, + LinearExpression rhs) { + return BoundedQuadraticExpression(std::move(rhs) - lhs, 0, 0); +} +BoundedQuadraticExpression operator==(const QuadraticTerm lhs, + const LinearTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} +BoundedQuadraticExpression operator==(const QuadraticTerm lhs, + const Variable rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} +BoundedQuadraticExpression operator==(const QuadraticTerm lhs, + const double rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} + +BoundedQuadraticExpression operator>=(const LinearExpression& lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression( + std::move(rhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(LinearExpression lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression( + rhs - std::move(lhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(const LinearExpression& lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(LinearExpression lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - std::move(lhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator==(const LinearExpression& lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, 0); +} +BoundedQuadraticExpression operator==(LinearExpression lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - std::move(lhs), 0, 0); +} +// LinearTerm -- +BoundedQuadraticExpression operator>=(const LinearTerm lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression( + std::move(rhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(const LinearTerm lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression( + rhs - lhs, -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(const LinearTerm lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(const LinearTerm lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator==(const LinearTerm lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, 0); +} +BoundedQuadraticExpression operator==(const LinearTerm lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} +// Variable -- +BoundedQuadraticExpression operator>=(const Variable lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression( + std::move(rhs), -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator>=(const Variable lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression( + rhs - lhs, -std::numeric_limits::infinity(), 0); +} +BoundedQuadraticExpression operator<=(const Variable lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator<=(const Variable lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, + std::numeric_limits::infinity()); +} +BoundedQuadraticExpression operator==(const Variable lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, 0); +} +BoundedQuadraticExpression operator==(const Variable lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} + +// Double -- +BoundedQuadraticExpression operator==(const double lhs, + QuadraticExpression rhs) { + rhs -= lhs; + return BoundedQuadraticExpression(std::move(rhs), 0, 0); +} +BoundedQuadraticExpression operator==(const double lhs, + const QuadraticTerm rhs) { + return BoundedQuadraticExpression(rhs - lhs, 0, 0); +} + } // namespace math_opt } // namespace operations_research diff --git a/ortools/math_opt/io/proto_converter.cc b/ortools/math_opt/io/proto_converter.cc index 5083ecf0af..923817368b 100644 --- a/ortools/math_opt/io/proto_converter.cc +++ b/ortools/math_opt/io/proto_converter.cc @@ -50,10 +50,6 @@ absl::Status IsSupported(const MPModelProto& model) { return absl::OkStatus(); } -absl::Status IsSupported(const math_opt::ModelProto& model) { - return ValidateModel(model); -} - bool AnyVarNamed(const MPModelProto& model) { for (const MPVariableProto& var : model.variable()) { if (var.name().length() > 0) { @@ -205,7 +201,7 @@ MPModelProtoToMathOptModel(const ::operations_research::MPModelProto& model) { absl::StatusOr<::operations_research::MPModelProto> MathOptModelToMPModelProto( const ::operations_research::math_opt::ModelProto& model) { - RETURN_IF_ERROR(IsSupported(model)); + RETURN_IF_ERROR(ValidateModel(model).status()); const bool vars_have_name = model.variables().names_size() > 0; const bool constraints_have_name = diff --git a/ortools/math_opt/result.proto b/ortools/math_opt/result.proto index 016346f87c..adf5faae48 100644 --- a/ortools/math_opt/result.proto +++ b/ortools/math_opt/result.proto @@ -232,6 +232,8 @@ enum LimitProto { // All information regarding why a call to Solve() terminated. message TerminationProto { + // Additional information in `limit` when value is TERMINATION_REASON_FEASIBLE + // or TERMINATION_REASON_NO_SOLUTION_FOUND, see `limit` for details. TerminationReasonProto reason = 1; // Is LIMIT_UNSPECIFIED unless reason is TERMINATION_REASON_FEASIBLE or diff --git a/ortools/math_opt/solvers/gurobi/g_gurobi.cc b/ortools/math_opt/solvers/gurobi/g_gurobi.cc index a1070055dd..1dbca6ae34 100644 --- a/ortools/math_opt/solvers/gurobi/g_gurobi.cc +++ b/ortools/math_opt/solvers/gurobi/g_gurobi.cc @@ -17,6 +17,7 @@ #include #include +#include "ortools/base/logging.h" #include "ortools/base/cleanup.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -180,7 +181,7 @@ absl::Status Gurobi::AddVars(const absl::Span vbegin, const absl::Span vtype, const absl::Span names) { CHECK_EQ(vind.size(), vval.size()); - const int num_vars = lb.size(); + const int num_vars = static_cast(lb.size()); CHECK_EQ(ub.size(), num_vars); CHECK_EQ(vtype.size(), num_vars); double* c_obj = nullptr; @@ -220,7 +221,7 @@ absl::Status Gurobi::DelVars(const absl::Span ind) { absl::Status Gurobi::AddConstrs(const absl::Span sense, const absl::Span rhs, const absl::Span names) { - const int num_cons = sense.size(); + const int num_cons = static_cast(sense.size()); CHECK_EQ(rhs.size(), num_cons); char** c_names = nullptr; std::vector c_names_data; @@ -247,7 +248,7 @@ absl::Status Gurobi::DelConstrs(const absl::Span ind) { absl::Status Gurobi::AddQpTerms(const absl::Span qrow, const absl::Span qcol, const absl::Span qval) { - const int numqnz = qrow.size(); + const int numqnz = static_cast(qrow.size()); CHECK_EQ(qcol.size(), numqnz); CHECK_EQ(qval.size(), numqnz); return ToStatus(GRBaddqpterms( @@ -257,10 +258,43 @@ absl::Status Gurobi::AddQpTerms(const absl::Span qrow, absl::Status Gurobi::DelQ() { return ToStatus(GRBdelq(gurobi_model_)); } +absl::Status Gurobi::AddQConstr(const absl::Span lind, + const absl::Span lval, + const absl::Span qrow, + const absl::Span qcol, + const absl::Span qval, + const char sense, const double rhs, + const std::string& name) { + const int numlnz = static_cast(lind.size()); + CHECK_EQ(lval.size(), numlnz); + + const int numqlnz = static_cast(qrow.size()); + CHECK_EQ(qcol.size(), numqlnz); + CHECK_EQ(qval.size(), numqlnz); + + return ToStatus(GRBaddqconstr( + /*model=*/gurobi_model_, + /*numlnz=*/numlnz, + /*lind=*/const_cast(lind.data()), + /*lval=*/const_cast(lval.data()), + /*numqlnz=*/numqlnz, + /*qrow=*/const_cast(qrow.data()), + /*qcol=*/const_cast(qcol.data()), + /*qval=*/const_cast(qval.data()), + /*sense=*/sense, + /*rhs=*/rhs, + /*constrname=*/const_cast(name.c_str()))); +} + +absl::Status Gurobi::DelQConstrs(const absl::Span ind) { + return ToStatus(GRBdelqconstrs(gurobi_model_, static_cast(ind.size()), + const_cast(ind.data()))); +} + absl::Status Gurobi::ChgCoeffs(const absl::Span cind, const absl::Span vind, const absl::Span val) { - const int num_changes = cind.size(); + const int num_changes = static_cast(cind.size()); CHECK_EQ(vind.size(), num_changes); CHECK_EQ(val.size(), num_changes); return ToStatus(GRBchgcoeffs( @@ -451,7 +485,7 @@ absl::StatusOr> Gurobi::GetCharAttrArray( absl::Status Gurobi::SetIntAttrList(const char* const name, const absl::Span ind, const absl::Span new_values) { - const int len = ind.size(); + const int len = static_cast(ind.size()); CHECK_EQ(new_values.size(), len); return ToStatus(GRBsetintattrlist(gurobi_model_, name, len, const_cast(ind.data()), @@ -461,7 +495,7 @@ absl::Status Gurobi::SetIntAttrList(const char* const name, absl::Status Gurobi::SetDoubleAttrList( const char* const name, const absl::Span ind, const absl::Span new_values) { - const int len = ind.size(); + const int len = static_cast(ind.size()); CHECK_EQ(new_values.size(), len); return ToStatus(GRBsetdblattrlist(gurobi_model_, name, len, const_cast(ind.data()), @@ -471,7 +505,7 @@ absl::Status Gurobi::SetDoubleAttrList( absl::Status Gurobi::SetCharAttrList(const char* const name, const absl::Span ind, const absl::Span new_values) { - const int len = ind.size(); + const int len = static_cast(ind.size()); CHECK_EQ(new_values.size(), len); return ToStatus(GRBsetcharattrlist(gurobi_model_, name, len, const_cast(ind.data()), @@ -559,7 +593,7 @@ absl::StatusOr Gurobi::CallbackContext::CbGetMessage() const { absl::Status Gurobi::CallbackContext::CbCut( const absl::Span cutind, const absl::Span cutval, const char cutsense, const double cutrhs) const { - const int cut_len = cutind.size(); + const int cut_len = static_cast(cutind.size()); CHECK_EQ(cutval.size(), cut_len); return gurobi_->ToStatus( GRBcbcut(cb_data_, cut_len, const_cast(cutind.data()), @@ -569,7 +603,7 @@ absl::Status Gurobi::CallbackContext::CbCut( absl::Status Gurobi::CallbackContext::CbLazy( const absl::Span lazyind, const absl::Span lazyval, const char lazysense, const double lazyrhs) const { - const int lazy_len = lazyind.size(); + const int lazy_len = static_cast(lazyind.size()); CHECK_EQ(lazyval.size(), lazy_len); return gurobi_->ToStatus( GRBcblazy(cb_data_, lazy_len, const_cast(lazyind.data()), diff --git a/ortools/math_opt/solvers/gurobi/g_gurobi.h b/ortools/math_opt/solvers/gurobi/g_gurobi.h index 7bc1941230..8e8e9c1ed6 100644 --- a/ortools/math_opt/solvers/gurobi/g_gurobi.h +++ b/ortools/math_opt/solvers/gurobi/g_gurobi.h @@ -331,6 +331,23 @@ class Gurobi { // Deletes all quadratic objective coefficients. absl::Status DelQ(); + // Calls GRBaddqconstr(). + // + // Requirements: + // * lind and lval must be equal length. + // * qrow, qcol, and qval must be equal length. + absl::Status AddQConstr(absl::Span lind, + absl::Span lval, + absl::Span qrow, + absl::Span qcol, + absl::Span qval, char sense, double rhs, + const std::string& name); + + // Calls GRBdelqconstrs(). + // + // Deletes the specified quadratic constraints. + absl::Status DelQConstrs(const absl::Span ind); + ////////////////////////////////////////////////////////////////////////////// // Linear constraint matrix queries. ////////////////////////////////////////////////////////////////////////////// diff --git a/ortools/math_opt/tools/mathopt_solve_main.cc b/ortools/math_opt/tools/mathopt_solve_main.cc index ee30e4f5a4..cf9e05d9b8 100644 --- a/ortools/math_opt/tools/mathopt_solve_main.cc +++ b/ortools/math_opt/tools/mathopt_solve_main.cc @@ -36,6 +36,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "google/protobuf/text_format.h" #include "ortools/base/file.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" @@ -117,6 +118,9 @@ ABSL_FLAG(operations_research::math_opt::SolverType, solver_type, operations_research::math_opt::AllSolversRegistry::Instance() ->RegisteredSolvers(), ", ", SolverTypeProtoFormatter()))); +ABSL_FLAG(std::string, solve_parameters, "", + "SolveParameters in text-proto format. Note that the time limit is " + "overridden by the --time_limit flag."); ABSL_FLAG(bool, solver_logs, false, "use a message callback to print the solver convergence logs"); ABSL_FLAG(absl::Duration, time_limit, absl::InfiniteDuration(), @@ -306,9 +310,16 @@ absl::Status RunSolver() { } // Solve the problem. + SolveParametersProto solve_parameters_proto; + QCHECK(google::protobuf::TextFormat::ParseFromString( + absl::GetFlag(FLAGS_solve_parameters), &solve_parameters_proto)) + << "Unable to parse --solve_parameters"; + ASSIGN_OR_RETURN(const SolveParameters solve_parameters, + SolveParameters::FromProto(solve_parameters_proto)); SolveArguments solve_args = { - .parameters = {.time_limit = absl::GetFlag(FLAGS_time_limit)}, + .parameters = solve_parameters, }; + solve_args.parameters.time_limit = absl::GetFlag(FLAGS_time_limit); if (absl::GetFlag(FLAGS_solver_logs)) { solve_args.message_callback = PrinterMessageCallback(std::cout, "logs| "); } diff --git a/ortools/math_opt/validators/BUILD.bazel b/ortools/math_opt/validators/BUILD.bazel index 5aeca87066..bff122f746 100644 --- a/ortools/math_opt/validators/BUILD.bazel +++ b/ortools/math_opt/validators/BUILD.bazel @@ -53,31 +53,12 @@ cc_library( ], ) -cc_library( - name = "name_validator", - srcs = ["name_validator.cc"], - hdrs = ["name_validator.h"], - deps = [ - ":sparse_vector_validator", - "//ortools/base", - "//ortools/base:map_util", - "//ortools/base:status_macros", - "//ortools/math_opt/core:model_summary", - "//ortools/math_opt/core:sparse_vector_view", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - cc_library( name = "model_validator", srcs = ["model_validator.cc"], hdrs = ["model_validator.h"], deps = [ ":ids_validator", - ":name_validator", ":scalar_validator", ":sparse_matrix_validator", ":sparse_vector_validator", diff --git a/ortools/math_opt/validators/ids_validator.cc b/ortools/math_opt/validators/ids_validator.cc index 92dd9c9a98..339ac5362c 100644 --- a/ortools/math_opt/validators/ids_validator.cc +++ b/ortools/math_opt/validators/ids_validator.cc @@ -30,89 +30,7 @@ #include "ortools/base/status_macros.h" #include "ortools/math_opt/core/model_summary.h" -namespace operations_research { -namespace math_opt { - -namespace { -absl::Status CheckSortedIdsSubsetWithIndexOffset( - const absl::Span ids, - const absl::Span universe, const int64_t offset) { - int id_index = 0; - int universe_index = 0; - // NOTE(user): in the common case where ids and/or universe is consecutive, - // we can avoid iterating though the list and do interval based checks. - while (id_index < ids.size() && universe_index < universe.size()) { - if (universe[universe_index] < ids[id_index]) { - ++universe_index; - } else if (universe[universe_index] == ids[id_index]) { - ++id_index; - } else { - break; - } - } - if (id_index < ids.size()) { - return util::InvalidArgumentErrorBuilder() - << "Bad id: " << ids[id_index] << " (at index: " << id_index + offset - << ") found"; - } - return absl::OkStatus(); -} - -// Given a sorted, strictly increasing set of ids, provides `contains()` to -// check if another id is in the set in O(1) time. -// -// Implementation note: when ids are consecutive, they are stored as a single -// interval [lb, ub), otherwise they are stored as a hash table of integers. -class FastIdCheck { - public: - // ids must be sorted with unique strictly increasing entries. - explicit FastIdCheck(const absl::Span ids) { - if (ids.empty()) { - interval_mode_ = true; - } else if (ids.size() == ids.back() + 1 - ids.front()) { - interval_mode_ = true; - interval_lb_ = ids.front(); - interval_ub_ = ids.back() + 1; - } else { - ids_ = absl::flat_hash_set(ids.begin(), ids.end()); - } - } - bool contains(int64_t id) const { - if (interval_mode_) { - return id >= interval_lb_ && id < interval_ub_; - } else { - return ids_.contains(id); - } - } - - private: - bool interval_mode_ = false; - int64_t interval_lb_ = 0; - int64_t interval_ub_ = 0; - absl::flat_hash_set ids_; -}; - -// Checks that the elements of ids and bad_list have no overlap. -// -// Assumed: ids and bad_list are sorted in increasing order, repeats allowed. -absl::Status CheckSortedIdsNotBad(const absl::Span ids, - const absl::Span bad_list) { - int id_index = 0; - int bad_index = 0; - while (id_index < ids.size() && bad_index < bad_list.size()) { - if (bad_list[bad_index] < ids[id_index]) { - ++bad_index; - } else if (bad_list[bad_index] > ids[id_index]) { - ++id_index; - } else { - return util::InvalidArgumentErrorBuilder() - << "Bad id: " << ids[id_index] << " (at index: " << id_index - << ") found"; - } - } - return absl::OkStatus(); -} -} // namespace +namespace operations_research::math_opt { absl::Status CheckIdsRangeAndStrictlyIncreasing(absl::Span ids) { int64_t previous{-1}; @@ -133,95 +51,17 @@ absl::Status CheckIdsRangeAndStrictlyIncreasing(absl::Span ids) { return absl::OkStatus(); } -absl::Status CheckSortedIdsSubset(const absl::Span ids, - const absl::Span universe) { - RETURN_IF_ERROR(CheckSortedIdsSubsetWithIndexOffset(ids, universe, 0)); - return absl::OkStatus(); -} - -absl::Status CheckUnsortedIdsSubset(const absl::Span ids, - const absl::Span universe) { - if (ids.empty()) { - return absl::OkStatus(); - } - const FastIdCheck id_check(universe); - for (int i = 0; i < ids.size(); ++i) { - if (!id_check.contains(ids[i])) { +absl::Status CheckIdsSubset(absl::Span ids, + const IdNameBiMap& universe, + std::optional upper_bound) { + for (const int64_t id : ids) { + if (upper_bound.has_value() && id >= *upper_bound) { return util::InvalidArgumentErrorBuilder() - << "Bad id: " << ids[i] << " (at index: " << i << ") not found"; + << "id " << id + << " should be less than upper bound: " << *upper_bound; } - } - return absl::OkStatus(); -} - -absl::Status IdUpdateValidator::IsValid() const { - for (int i = 0; i < deleted_ids_.size(); ++i) { - const int64_t deleted_id = deleted_ids_[i]; - if (!old_ids_.HasId(deleted_id)) { - return util::InvalidArgumentErrorBuilder() - << "Tried to delete id: " << deleted_id << " (at index: " << i - << ") but it was not present"; - } - } - if (!new_ids_.empty() && new_ids_.front() < old_ids_.next_free_id()) { - return util::InvalidArgumentErrorBuilder() - << "All new ids should be greater or equal to the first unused id: " - << old_ids_.next_free_id() - << " but the first new id was: " << new_ids_.front(); - } - return absl::OkStatus(); -} - -absl::Status IdUpdateValidator::CheckSortedIdsSubsetOfNotDeleted( - const absl::Span ids) const { - RETURN_IF_ERROR(CheckSortedIdsNotBad(ids, deleted_ids_)) << " was deleted"; - for (int i = 0; i < ids.size(); ++i) { - if (!old_ids_.HasId(ids[i])) { - return util::InvalidArgumentErrorBuilder() - << "Bad id: " << ids[i] << " (at index: " << i << ") not found"; - } - } - return absl::OkStatus(); -} - -absl::Status IdUpdateValidator::CheckSortedIdsSubsetOfFinal( - const absl::Span ids) const { - // Implementation: - // * Partition ids into "old" and "new" - // * Check that the old ids are in old_ids_ but not deleted_ids_. - // * Check that the new ids are in new_ids_. - size_t split_point = ids.size(); - if (!new_ids_.empty()) { - split_point = std::distance( - ids.begin(), std::lower_bound(ids.begin(), ids.end(), new_ids_[0])); - } - RETURN_IF_ERROR( - CheckSortedIdsSubsetOfNotDeleted(ids.subspan(0, split_point))); - RETURN_IF_ERROR(CheckSortedIdsSubsetWithIndexOffset(ids.subspan(split_point), - new_ids_, split_point)); - return absl::OkStatus(); -} - -absl::Status IdUpdateValidator::CheckIdsSubsetOfFinal( - const absl::Span ids) const { - if (ids.empty()) { - return absl::OkStatus(); - } - const FastIdCheck deleted_fast(deleted_ids_); - const FastIdCheck new_fast(new_ids_); - for (int i = 0; i < ids.size(); ++i) { - const int64_t id = ids[i]; - if (!new_ids_.empty() && id >= new_ids_[0]) { - if (!new_fast.contains(id)) { - return util::InvalidArgumentErrorBuilder() - << "Bad id: " << id << " (at index: " << i << ") not found"; - } - } else if (!old_ids_.HasId(id)) { - return util::InvalidArgumentErrorBuilder() - << "Bad id: " << id << " (at index: " << i << ") not found"; - } else if (deleted_fast.contains(id)) { - return util::InvalidArgumentErrorBuilder() - << "Bad id: " << id << " (at index: " << i << ") was deleted"; + if (!universe.HasId(id)) { + return util::InvalidArgumentErrorBuilder() << "id " << id << " not found"; } } return absl::OkStatus(); @@ -256,5 +96,4 @@ absl::Status CheckIdsIdentical(absl::Span first_ids, return absl::OkStatus(); } -} // namespace math_opt -} // namespace operations_research +} // namespace operations_research::math_opt diff --git a/ortools/math_opt/validators/ids_validator.h b/ortools/math_opt/validators/ids_validator.h index 40baeda9ff..fb3ef51980 100644 --- a/ortools/math_opt/validators/ids_validator.h +++ b/ortools/math_opt/validators/ids_validator.h @@ -24,22 +24,19 @@ namespace operations_research { namespace math_opt { -// Checks that the input ids are in [0, max(int64_t)) range and that their are +// Checks that the input ids are in [0, max(int64_t)) range and that they are // strictly increasing. absl::Status CheckIdsRangeAndStrictlyIncreasing(absl::Span ids); -// Checks that the elements of ids are a subset of universe. +// Checks that the elements of ids are a subset of universe. Elements of ids +// do not need to be sorted or distinct. If upper_bound is set, elements must be +// strictly less than upper_bound. // -// Assumed: ids and universe are sorted in increasing order, repeats allowed. -absl::Status CheckSortedIdsSubset(const absl::Span ids, - const absl::Span universe); - -// Checks that the elements of ids are a subset of universe. -// -// Assumed: universe are sorted in strictly increasing order (no repeats). No -// assumptions on ids. -absl::Status CheckUnsortedIdsSubset(const absl::Span ids, - const absl::Span universe); +// TODO(b/232526223): try merge this with the CheckIdsSubset overload below, or +// at least have one call the other. +absl::Status CheckIdsSubset(absl::Span ids, + const IdNameBiMap& universe, + std::optional upper_bound = std::nullopt); // Checks that the elements of ids are a subset of universe. Elements of ids // do not need to be sorted or distinct. @@ -54,50 +51,6 @@ absl::Status CheckIdsIdentical(absl::Span first_ids, absl::string_view first_description, absl::string_view second_description); -// Provides a unified view of the id sets: -// * NOT_DELETED = old - deleted -// * FINAL = old - deleted + new -// so users can validate if a list of ids (sorted or unsorted) is a subset of -// either of the sets above. -// -// Implementation note: this class does not allocate by default, but some -// functions will allocate at most O(#deleted + #new). -class IdUpdateValidator { - public: - // deleted_ids and new_ids must be sorted with unique strictly increasing - // entries. - IdUpdateValidator(const IdNameBiMap& old_ids, - const absl::Span deleted_ids, - const absl::Span new_ids) - : old_ids_(old_ids), deleted_ids_(deleted_ids), new_ids_(new_ids) {} - - // Returns true if the sets of ids passed to the constructor are valid. - absl::Status IsValid() const; - - // Checks that ids is a subset of NOT_DELETED = old_ids_ - deleted_ids_. - // - // ids must be sorted in increasing order (repeats are allowed). - absl::Status CheckSortedIdsSubsetOfNotDeleted( - const absl::Span ids) const; - - // Checks that ids is a subset of FINAL = old_ids_ - deleted_ids_ + new_ids_. - // - // ids must be sorted in increasing order (repeats are allowed). - absl::Status CheckSortedIdsSubsetOfFinal( - const absl::Span ids) const; - - // Checks that ids is a subset of FINAL = old_ids_ - deleted_ids_ + new_ids_. - // - // If ids is sorted, prefer CheckSortedIdsSubsetOfFinal. - absl::Status CheckIdsSubsetOfFinal(const absl::Span ids) const; - - private: - // NOT OWNED - const IdNameBiMap& old_ids_; - const absl::Span deleted_ids_; - const absl::Span new_ids_; -}; - } // namespace math_opt } // namespace operations_research diff --git a/ortools/math_opt/validators/model_validator.cc b/ortools/math_opt/validators/model_validator.cc index 8baaac71ba..d7541c2efc 100644 --- a/ortools/math_opt/validators/model_validator.cc +++ b/ortools/math_opt/validators/model_validator.cc @@ -21,7 +21,6 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "ortools/base/integral_types.h" #include "ortools/base/status_macros.h" #include "ortools/math_opt/core/model_summary.h" #include "ortools/math_opt/core/sparse_vector_view.h" @@ -29,7 +28,6 @@ #include "ortools/math_opt/model_update.pb.h" #include "ortools/math_opt/sparse_containers.pb.h" #include "ortools/math_opt/validators/ids_validator.h" -#include "ortools/math_opt/validators/name_validator.h" #include "ortools/math_opt/validators/scalar_validator.h" #include "ortools/math_opt/validators/sparse_matrix_validator.h" #include "ortools/math_opt/validators/sparse_vector_validator.h" @@ -42,8 +40,7 @@ namespace { // Submessages //////////////////////////////////////////////////////////////////////////////// -absl::Status VariablesValid(const VariablesProto& variables, - const bool check_names) { +absl::Status VariablesValid(const VariablesProto& variables) { RETURN_IF_ERROR(CheckIdsRangeAndStrictlyIncreasing(variables.ids())) << "Bad variable ids"; RETURN_IF_ERROR( @@ -54,13 +51,12 @@ absl::Status VariablesValid(const VariablesProto& variables, {.allow_negative_infinity = false}, "upper_bounds")); RETURN_IF_ERROR( CheckValues(MakeView(variables.ids(), variables.integers()), "integers")); - RETURN_IF_ERROR(CheckNameVector(MakeView(variables.ids(), variables.names()), - check_names)); return absl::OkStatus(); } -absl::Status VariableUpdatesValid( - const VariableUpdatesProto& variable_updates) { +absl::Status VariableUpdatesValid(const VariableUpdatesProto& variable_updates, + const IdNameBiMap& variable_ids, + const int64_t old_var_id_ub) { RETURN_IF_ERROR(CheckIdsAndValues(MakeView(variable_updates.lower_bounds()), {.allow_positive_infinity = false})) << "Bad lower bounds"; @@ -69,26 +65,20 @@ absl::Status VariableUpdatesValid( << "Bad upper bounds"; RETURN_IF_ERROR(CheckIdsAndValues(MakeView(variable_updates.integers()))) << "Bad integers"; - return absl::OkStatus(); -} - -absl::Status VariableUpdatesValidForState( - const VariableUpdatesProto& variable_updates, - const IdUpdateValidator& id_validator) { - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfNotDeleted( - variable_updates.lower_bounds().ids())) + RETURN_IF_ERROR(CheckIdsSubset(variable_updates.lower_bounds().ids(), + variable_ids, old_var_id_ub)) << "lower bound update on invalid variable id"; - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfNotDeleted( - variable_updates.upper_bounds().ids())) + RETURN_IF_ERROR(CheckIdsSubset(variable_updates.upper_bounds().ids(), + variable_ids, old_var_id_ub)) << "upper bound update on invalid variable id"; - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfNotDeleted( - variable_updates.integers().ids())) + RETURN_IF_ERROR(CheckIdsSubset(variable_updates.integers().ids(), + variable_ids, old_var_id_ub)) << "integer update on invalid variable id"; return absl::OkStatus(); } absl::Status ObjectiveValid(const ObjectiveProto& objective, - absl::Span variable_ids) { + const IdNameBiMap& variable_ids) { // 1. Validate offset RETURN_IF_ERROR(CheckScalarNoNanNoInf(objective.offset())) << "Objective offset invalid"; @@ -98,7 +88,7 @@ absl::Status ObjectiveValid(const ObjectiveProto& objective, linear_coefficients, {.allow_positive_infinity = false, .allow_negative_infinity = false})) << "Linear objective coefficients bad"; - RETURN_IF_ERROR(CheckSortedIdsSubset(linear_coefficients.ids(), variable_ids)) + RETURN_IF_ERROR(CheckIdsSubset(linear_coefficients.ids(), variable_ids)) << "Objective.linear_coefficients.ids not found in Variables.ids"; // 3. Validate quadratic terms RETURN_IF_ERROR(SparseMatrixValid(objective.quadratic_coefficients(), @@ -112,7 +102,8 @@ absl::Status ObjectiveValid(const ObjectiveProto& objective, // NOTE: This method does not check requirements on the IDs absl::Status ObjectiveUpdatesValid( - const ObjectiveUpdatesProto& objective_updates) { + const ObjectiveUpdatesProto& objective_updates, + const IdNameBiMap& variable_ids) { // 1. Validate offset RETURN_IF_ERROR(CheckScalarNoNanNoInf(objective_updates.offset_update())) << "Offset update invalid"; @@ -120,31 +111,22 @@ absl::Status ObjectiveUpdatesValid( RETURN_IF_ERROR(CheckIdsAndValues( MakeView(objective_updates.linear_coefficients()), {.allow_positive_infinity = false, .allow_negative_infinity = false})) - << "Linear objective coefficients bad"; + << "Linear objective coefficients invalid"; // 3. Validate quadratic terms RETURN_IF_ERROR(SparseMatrixValid(objective_updates.quadratic_coefficients(), /*enforce_upper_triangular=*/true)) << "Objective.quadratic_coefficients invalid"; - return absl::OkStatus(); -} - -absl::Status ObjectiveUpdatesValidForModel( - const ObjectiveUpdatesProto& objective_updates, - const IdUpdateValidator& id_validator) { - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfFinal( - objective_updates.linear_coefficients().ids())) + RETURN_IF_ERROR(CheckIdsSubset(objective_updates.linear_coefficients().ids(), + variable_ids)) << "Linear coefficients ids not found in variable ids"; - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfFinal( - objective_updates.quadratic_coefficients().row_ids())) - << "Quadratic coefficient ids bad"; - RETURN_IF_ERROR(id_validator.CheckIdsSubsetOfFinal( - objective_updates.quadratic_coefficients().column_ids())) - << "Quadratic coefficient ids bad"; + RETURN_IF_ERROR(SparseMatrixIdsAreKnown( + objective_updates.quadratic_coefficients(), variable_ids, variable_ids)) + << "quadratic_coefficients invalid"; return absl::OkStatus(); } absl::Status LinearConstraintsValid( - const LinearConstraintsProto& linear_constraints, const bool check_names) { + const LinearConstraintsProto& linear_constraints) { RETURN_IF_ERROR(CheckIdsRangeAndStrictlyIncreasing(linear_constraints.ids())) << "Bad linear constraint ids"; RETURN_IF_ERROR(CheckValues( @@ -153,14 +135,12 @@ absl::Status LinearConstraintsValid( RETURN_IF_ERROR(CheckValues( MakeView(linear_constraints.ids(), linear_constraints.upper_bounds()), {.allow_negative_infinity = false}, "upper_bounds")); - RETURN_IF_ERROR(CheckNameVector( - MakeView(linear_constraints.ids(), linear_constraints.names()), - check_names)); return absl::OkStatus(); } absl::Status LinearConstraintUpdatesValid( - const LinearConstraintUpdatesProto& linear_constraint_updates) { + const LinearConstraintUpdatesProto& linear_constraint_updates, + const IdNameBiMap& linear_constraint_ids, const int64_t old_lin_con_id_ub) { RETURN_IF_ERROR( CheckIdsAndValues(MakeView(linear_constraint_updates.lower_bounds()), {.allow_positive_infinity = false})) @@ -169,30 +149,21 @@ absl::Status LinearConstraintUpdatesValid( CheckIdsAndValues(MakeView(linear_constraint_updates.upper_bounds()), {.allow_negative_infinity = false})) << "Bad upper bounds"; - return absl::OkStatus(); -} - -absl::Status LinearConstraintUpdatesValidForState( - const LinearConstraintUpdatesProto& linear_constraint_updates, - const IdUpdateValidator& id_validator) { - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfNotDeleted( - linear_constraint_updates.lower_bounds().ids())) + RETURN_IF_ERROR(CheckIdsSubset(linear_constraint_updates.lower_bounds().ids(), + linear_constraint_ids, old_lin_con_id_ub)) << "lower bound update on invalid linear constraint id"; - RETURN_IF_ERROR(id_validator.CheckSortedIdsSubsetOfNotDeleted( - linear_constraint_updates.upper_bounds().ids())) + RETURN_IF_ERROR(CheckIdsSubset(linear_constraint_updates.upper_bounds().ids(), + linear_constraint_ids, old_lin_con_id_ub)) << "upper bound update on invalid linear constraint id"; return absl::OkStatus(); } absl::Status LinearConstraintMatrixIdsValidForUpdate( const SparseDoubleMatrixProto& matrix, - const IdUpdateValidator& linear_constraint_id_validator, - const IdUpdateValidator& variable_id_validator) { - RETURN_IF_ERROR(linear_constraint_id_validator.CheckSortedIdsSubsetOfFinal( - matrix.row_ids())) + const IdNameBiMap& linear_constraint_ids, const IdNameBiMap& variable_ids) { + RETURN_IF_ERROR(CheckIdsSubset(matrix.row_ids(), linear_constraint_ids)) << "Unknown linear_constraint_id"; - RETURN_IF_ERROR( - variable_id_validator.CheckIdsSubsetOfFinal(matrix.column_ids())) + RETURN_IF_ERROR(CheckIdsSubset(matrix.column_ids(), variable_ids)) << "Unknown variable_id"; return absl::OkStatus(); } @@ -203,21 +174,23 @@ absl::Status LinearConstraintMatrixIdsValidForUpdate( // Model // ///////////////////////////////////////////////////////////////////////////// -absl::Status ValidateModel(const ModelProto& model, const bool check_names) { - RETURN_IF_ERROR(VariablesValid(model.variables(), check_names)) +absl::StatusOr ValidateModel(const ModelProto& model, + const bool check_names) { + ASSIGN_OR_RETURN(const auto model_summary, + ModelSummary::Create(model, check_names)); + RETURN_IF_ERROR(VariablesValid(model.variables())) << "Model.variables are invalid."; - RETURN_IF_ERROR(ObjectiveValid(model.objective(), model.variables().ids())) + RETURN_IF_ERROR(ObjectiveValid(model.objective(), model_summary.variables)) << "Model.objective is invalid"; - RETURN_IF_ERROR( - LinearConstraintsValid(model.linear_constraints(), check_names)) + RETURN_IF_ERROR(LinearConstraintsValid(model.linear_constraints())) << "Model.linear_constraints are invalid"; RETURN_IF_ERROR(SparseMatrixValid(model.linear_constraint_matrix())) << "Model.linear_constraint_matrix invalid"; RETURN_IF_ERROR(SparseMatrixIdsAreKnown(model.linear_constraint_matrix(), - model.linear_constraints().ids(), - model.variables().ids())) + model_summary.linear_constraints, + model_summary.variables)) << "Model.linear_constraint_matrix ids are inconsistent"; - return absl::OkStatus(); + return model_summary; } //////////////////////////////////////////////////////////////////////////////// @@ -225,77 +198,38 @@ absl::Status ValidateModel(const ModelProto& model, const bool check_names) { //////////////////////////////////////////////////////////////////////////////// absl::Status ValidateModelUpdate(const ModelUpdateProto& model_update, - const bool check_names) { - RETURN_IF_ERROR(CheckIdsRangeAndStrictlyIncreasing( - model_update.deleted_linear_constraint_ids())) - << "ModelUpdateProto.deleted_linear_constraint_ids invalid"; - RETURN_IF_ERROR( - CheckIdsRangeAndStrictlyIncreasing(model_update.deleted_variable_ids())) - << "ModelUpdateProto.deleted_variable_ids invalid"; - RETURN_IF_ERROR(VariableUpdatesValid(model_update.variable_updates())) + ModelSummary& model_summary) { + RETURN_IF_ERROR(model_summary.Update(model_update)); + const int64_t old_var_id_ub = model_update.new_variables().ids_size() > 0 + ? model_update.new_variables().ids(0) + : model_summary.variables.next_free_id(); + const int64_t old_lin_con_id_ub = + model_update.new_linear_constraints().ids_size() > 0 + ? model_update.new_linear_constraints().ids(0) + : model_summary.linear_constraints.next_free_id(); + RETURN_IF_ERROR(VariableUpdatesValid(model_update.variable_updates(), + model_summary.variables, old_var_id_ub)) << "ModelUpdateProto.variable_updates invalid"; - RETURN_IF_ERROR( - LinearConstraintUpdatesValid(model_update.linear_constraint_updates())) + RETURN_IF_ERROR(LinearConstraintUpdatesValid( + model_update.linear_constraint_updates(), + model_summary.linear_constraints, old_lin_con_id_ub)) << "ModelUpdateProto.linear_constraint_updates invalid"; - RETURN_IF_ERROR(VariablesValid(model_update.new_variables(), check_names)) + RETURN_IF_ERROR(VariablesValid(model_update.new_variables())) << "ModelUpdateProto.new_variables invalid"; - RETURN_IF_ERROR(LinearConstraintsValid(model_update.new_linear_constraints(), - check_names)) + RETURN_IF_ERROR(LinearConstraintsValid(model_update.new_linear_constraints())) << "ModelUpdateProto.new_linear_constraints invalid"; - RETURN_IF_ERROR(ObjectiveUpdatesValid(model_update.objective_updates())) + RETURN_IF_ERROR(ObjectiveUpdatesValid(model_update.objective_updates(), + model_summary.variables)) << "ModelUpdateProto.objective_update invalid"; RETURN_IF_ERROR( SparseMatrixValid(model_update.linear_constraint_matrix_updates())) << "Model.linear_constraint_matrix_updates invalid"; - return absl::OkStatus(); -} -absl::Status ValidateModelUpdateAndSummary(const ModelUpdateProto& model_update, - const ModelSummary& model_summary, - const bool check_names) { - RETURN_IF_ERROR(ValidateModelUpdate(model_update)); - const IdUpdateValidator variable_id_validator( - model_summary.variables, model_update.deleted_variable_ids(), - model_update.new_variables().ids()); - RETURN_IF_ERROR(variable_id_validator.IsValid()) - << "Invalid new or deleted variable id"; - const IdUpdateValidator linear_constraint_id_validator( - model_summary.linear_constraints, - model_update.deleted_linear_constraint_ids(), - model_update.new_linear_constraints().ids()); - RETURN_IF_ERROR(linear_constraint_id_validator.IsValid()) - << "Invalid new or deleted linear constraint id"; - - RETURN_IF_ERROR(VariableUpdatesValidForState(model_update.variable_updates(), - variable_id_validator)) - << "Invalid variable update"; - - RETURN_IF_ERROR(LinearConstraintUpdatesValidForState( - model_update.linear_constraint_updates(), linear_constraint_id_validator)) - << "Invalid linear constraint update"; - - RETURN_IF_ERROR(ObjectiveUpdatesValidForModel( - model_update.objective_updates(), variable_id_validator)) - << "Invalid objective update"; RETURN_IF_ERROR(LinearConstraintMatrixIdsValidForUpdate( model_update.linear_constraint_matrix_updates(), - linear_constraint_id_validator, variable_id_validator)) - << "Invalid linear constraint matrix update"; - if (check_names && !model_update.new_variables().names().empty()) { - RETURN_IF_ERROR( - CheckNewNames(model_summary.variables, - MakeView(model_update.new_variables().ids(), - model_update.new_variables().names()))) - << "Bad new variable names"; - } + model_summary.linear_constraints, model_summary.variables)) + << "invalid linear constraint matrix update"; - if (check_names && !model_update.new_linear_constraints().names().empty()) { - RETURN_IF_ERROR( - CheckNewNames(model_summary.linear_constraints, - MakeView(model_update.new_linear_constraints().ids(), - model_update.new_linear_constraints().names()))) - << "Bad new linear constraint names"; - } return absl::OkStatus(); } diff --git a/ortools/math_opt/validators/model_validator.h b/ortools/math_opt/validators/model_validator.h index e233e574dc..eb481d0eb3 100644 --- a/ortools/math_opt/validators/model_validator.h +++ b/ortools/math_opt/validators/model_validator.h @@ -24,33 +24,19 @@ namespace math_opt { // Runs in O(size of model) and allocates O(#variables + #linear constraints) // memory. -absl::Status ValidateModel(const ModelProto& model, bool check_names = true); +absl::StatusOr ValidateModel(const ModelProto& model, + bool check_names = true); -// Validates the update as-is; without any knowledge of the model or previous -// updates. Some tests of the validity of ids are also not done. -// -// Performance: runs in O(size of update). -// -// See ValidateModelUpdateAndSummary() for a version that does a full -// validation taking into account the model and previous updates. -absl::Status ValidateModelUpdate(const ModelUpdateProto& model_update, - bool check_names = true); - -// Validates the update taking into account the model and previous updates (via -// the provided summary). -// -// Note that this function uses model_summary.(variables|linear_constraints)'s -// next_free_id() to test that new variables/constraints ids are valid. +// Validates the update is consistent both internally and with current model (as +// given by model_summary) and updates the model_summary. // // Performance: runs in O(size of update), allocates at most // O(#new or deleted variables + #new or deleted linear constraints). // -// It internally calls ValidateModelUpdate() which validates all predicates that -// can be validated without knowledge of the initial model and the previous -// updates. -absl::Status ValidateModelUpdateAndSummary(const ModelUpdateProto& model_update, - const ModelSummary& model_summary, - bool check_names = true); +// If the function returns an error, no guarantees are made on the state of +// model_summary. +absl::Status ValidateModelUpdate(const ModelUpdateProto& model_update, + ModelSummary& model_summary); } // namespace math_opt } // namespace operations_research diff --git a/ortools/math_opt/validators/name_validator.cc b/ortools/math_opt/validators/name_validator.cc deleted file mode 100644 index 98a8daca47..0000000000 --- a/ortools/math_opt/validators/name_validator.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright 2010-2021 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ortools/math_opt/validators/name_validator.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "ortools/base/integral_types.h" -#include "ortools/base/map_util.h" -#include "ortools/base/status_macros.h" -#include "ortools/math_opt/core/model_summary.h" -#include "ortools/math_opt/core/sparse_vector_view.h" -#include "ortools/math_opt/validators/sparse_vector_validator.h" - -namespace operations_research { -namespace math_opt { - -absl::Status CheckNameVector( - const SparseVectorView& name_vector, - const bool check_unique) { - if (name_vector.values().empty()) { - // Names are optional. - return absl::OkStatus(); - } - RETURN_IF_ERROR(CheckIdsAndValuesSize(name_vector, "names")); - absl::flat_hash_map used_variable_names; - if (check_unique) { - for (const auto [id, name_pointer] : name_vector) { - const std::string& name = *name_pointer; - if (!name.empty()) { - if (!gtl::InsertIfNotPresent(&used_variable_names, {name, id})) { - return absl::InvalidArgumentError( - absl::StrCat("Found name: ", name, " twice, for ids ", id, - " and ", used_variable_names.at(name))); - } - } - } - } - return absl::OkStatus(); -} - -absl::Status CheckNewNames( - const IdNameBiMap& old_names, - const SparseVectorView& new_names) { - if (old_names.Empty()) { - return absl::OkStatus(); - } - for (const auto [id, name_pointer] : new_names) { - const std::string& new_name = *name_pointer; - if (!new_name.empty() && old_names.HasName(new_name)) { - return absl::InvalidArgumentError( - absl::StrCat("Found name: ", new_name, " twice, for ids ", id, - " and ", old_names.nonempty_name_to_id().at(new_name))); - } - } - return absl::OkStatus(); -} - -} // namespace math_opt -} // namespace operations_research diff --git a/ortools/math_opt/validators/name_validator.h b/ortools/math_opt/validators/name_validator.h deleted file mode 100644 index e00224ecf0..0000000000 --- a/ortools/math_opt/validators/name_validator.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2010-2021 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef OR_TOOLS_MATH_OPT_VALIDATORS_NAME_VALIDATOR_H_ -#define OR_TOOLS_MATH_OPT_VALIDATORS_NAME_VALIDATOR_H_ - -#include - -#include "absl/status/status.h" -#include "ortools/math_opt/core/model_summary.h" -#include "ortools/math_opt/core/sparse_vector_view.h" - -namespace operations_research { -namespace math_opt { - -// Checks basic validity of name_vector view: i.e. ids_size() = values_size(). -// In addition, if check_unique is set to true, the function checks that every -// name that is not "" is distinct. -absl::Status CheckNameVector( - const SparseVectorView& name_vector, bool check_unique); - -// Checks new_names are compatible with old_names: i.e. new_names does not -// duplicate names in old_names. Assumes basic validity of new_names view and -// does not check for duplicates within old_names or new_names. -absl::Status CheckNewNames( - const IdNameBiMap& old_names, - const SparseVectorView& new_names); - -} // namespace math_opt -} // namespace operations_research - -#endif // OR_TOOLS_MATH_OPT_VALIDATORS_NAME_VALIDATOR_H_ diff --git a/ortools/math_opt/validators/result_validator.cc b/ortools/math_opt/validators/result_validator.cc index c404c5ed4c..3e629ad340 100644 --- a/ortools/math_opt/validators/result_validator.cc +++ b/ortools/math_opt/validators/result_validator.cc @@ -31,33 +31,6 @@ namespace operations_research { namespace math_opt { namespace { -absl::Status ValidateTermination(const TerminationProto& termination) { - if (termination.reason() == TERMINATION_REASON_UNSPECIFIED) { - return absl::InvalidArgumentError("termination reason must be specified"); - } - if (termination.reason() == TERMINATION_REASON_FEASIBLE || - termination.reason() == TERMINATION_REASON_NO_SOLUTION_FOUND) { - if (termination.limit() == LIMIT_UNSPECIFIED) { - return absl::InvalidArgumentError( - absl::StrCat("for reason ", ProtoEnumToString(termination.reason()), - ", limit must be specified")); - } - if (termination.limit() == LIMIT_CUTOFF && - termination.reason() == TERMINATION_REASON_FEASIBLE) { - return absl::InvalidArgumentError( - "For LIMIT_CUTOFF expected no solutions"); - } - } else { - if (termination.limit() != LIMIT_UNSPECIFIED) { - return absl::InvalidArgumentError( - absl::StrCat("for reason:", ProtoEnumToString(termination.reason()), - ", limit should be unspecified, but was set to: ", - ProtoEnumToString(termination.limit()))); - } - } - return absl::OkStatus(); -} - bool HasPrimalFeasibleSolution(const SolutionProto& solution) { return solution.has_primal_solution() && solution.primal_solution().feasibility_status() == @@ -146,6 +119,33 @@ absl::Status RequireNoDualFeasibleSolution(const SolveResultProto& result) { } } // namespace +absl::Status ValidateTermination(const TerminationProto& termination) { + if (termination.reason() == TERMINATION_REASON_UNSPECIFIED) { + return absl::InvalidArgumentError("termination reason must be specified"); + } + if (termination.reason() == TERMINATION_REASON_FEASIBLE || + termination.reason() == TERMINATION_REASON_NO_SOLUTION_FOUND) { + if (termination.limit() == LIMIT_UNSPECIFIED) { + return absl::InvalidArgumentError( + absl::StrCat("for reason ", ProtoEnumToString(termination.reason()), + ", limit must be specified")); + } + if (termination.limit() == LIMIT_CUTOFF && + termination.reason() == TERMINATION_REASON_FEASIBLE) { + return absl::InvalidArgumentError( + "For LIMIT_CUTOFF expected no solutions"); + } + } else { + if (termination.limit() != LIMIT_UNSPECIFIED) { + return absl::InvalidArgumentError( + absl::StrCat("for reason:", ProtoEnumToString(termination.reason()), + ", limit should be unspecified, but was set to: ", + ProtoEnumToString(termination.limit()))); + } + } + return absl::OkStatus(); +} + absl::Status CheckHasPrimalSolution(const SolveResultProto& result) { if (!HasPrimalFeasibleSolution(result)) { return absl::InvalidArgumentError( diff --git a/ortools/math_opt/validators/result_validator.h b/ortools/math_opt/validators/result_validator.h index ff4b3ce667..d0504829e3 100644 --- a/ortools/math_opt/validators/result_validator.h +++ b/ortools/math_opt/validators/result_validator.h @@ -22,6 +22,13 @@ namespace operations_research { namespace math_opt { +// Checks that: +// * termination.reason is not UNSPECIFIED, +// * termination.limit is set (not UNSPECIFIED) iff termination.reason is +// either FEASIBLE or NO_SOLUTION_FOUND, +// * termination.limit is not CUTOFF when termination.reason is FEASIBLE. +absl::Status ValidateTermination(const TerminationProto& termination); + // Validates the input result. absl::Status ValidateResult(const SolveResultProto& result, const ModelSolveParametersProto& parameters, diff --git a/ortools/math_opt/validators/sparse_matrix_validator.cc b/ortools/math_opt/validators/sparse_matrix_validator.cc index 65992d85d0..8d8866298d 100644 --- a/ortools/math_opt/validators/sparse_matrix_validator.cc +++ b/ortools/math_opt/validators/sparse_matrix_validator.cc @@ -88,13 +88,12 @@ absl::Status SparseMatrixValid(const SparseDoubleMatrixProto& matrix, return absl::OkStatus(); } -absl::Status SparseMatrixIdsAreKnown( - const SparseDoubleMatrixProto& matrix, - const absl::Span row_ids, - const absl::Span column_ids) { - RETURN_IF_ERROR(CheckSortedIdsSubset(matrix.row_ids(), row_ids)) +absl::Status SparseMatrixIdsAreKnown(const SparseDoubleMatrixProto& matrix, + const IdNameBiMap& row_ids, + const IdNameBiMap& column_ids) { + RETURN_IF_ERROR(CheckIdsSubset(matrix.row_ids(), row_ids)) << "Unknown row_id"; - RETURN_IF_ERROR(CheckUnsortedIdsSubset(matrix.column_ids(), column_ids)) + RETURN_IF_ERROR(CheckIdsSubset(matrix.column_ids(), column_ids)) << "Unknown column_id"; return absl::OkStatus(); } diff --git a/ortools/math_opt/validators/sparse_matrix_validator.h b/ortools/math_opt/validators/sparse_matrix_validator.h index e4facea7a2..c08357584c 100644 --- a/ortools/math_opt/validators/sparse_matrix_validator.h +++ b/ortools/math_opt/validators/sparse_matrix_validator.h @@ -18,6 +18,7 @@ #include "absl/status/status.h" #include "absl/types/span.h" +#include "ortools/math_opt/core/model_summary.h" #include "ortools/math_opt/model.pb.h" namespace operations_research::math_opt { @@ -36,8 +37,8 @@ absl::Status SparseMatrixValid(const SparseDoubleMatrixProto& matrix, // 1. matrix.row_ids is a subset of row_ids. // 2. matrix.column_ids is a subset of column_ids. absl::Status SparseMatrixIdsAreKnown(const SparseDoubleMatrixProto& matrix, - absl::Span row_ids, - absl::Span column_ids); + const IdNameBiMap& row_ids, + const IdNameBiMap& column_ids); } // namespace operations_research::math_opt