From 058618a9c44ebab22f634998e64aedba6da1b8e2 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 20 Sep 2021 15:24:56 +0200 Subject: [PATCH] fix more corner cases bugs --- ortools/sat/cp_model.proto | 4 ++ ortools/sat/cp_model_checker.cc | 30 +++++++++++-- ortools/sat/cp_model_presolve.cc | 76 ++++++++++++++++++++------------ ortools/sat/cp_model_presolve.h | 17 ++++--- 4 files changed, 90 insertions(+), 37 deletions(-) diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index 573c7f3593..638f45f6f0 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -401,6 +401,10 @@ message ConstraintProto { // variables. By convention, because we can just remove term equal to one, // the empty product forces the target to be one. // + // Note that the solver checks for potential integer overflow. So it is + // recommended to limit the domain of the variables such that the product + // fits in [INT_MIN + 1..INT_MAX - 1]. + // // TODO(user): Support more than two terms in the product. IntegerArgumentProto int_prod = 11; diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index e58798f565..acdf8503bd 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -319,6 +319,19 @@ std::string ValidateIntProdConstraint(const CpModelProto& model, return absl::StrCat("An int_prod constraint should have exactly 2 terms: ", ProtobufShortDebugString(ct)); } + + // Detect potential overflow if some of the variables span across 0. + const Domain product_domain = + ReadDomainFromProto(model.variables(PositiveRef(ct.int_prod().vars(0)))) + .ContinuousMultiplicationBy(ReadDomainFromProto( + model.variables(PositiveRef(ct.int_prod().vars(1))))); + if ((product_domain.Max() == std::numeric_limits::max() && + product_domain.Min() < 0) || + (product_domain.Min() == std::numeric_limits::min() && + product_domain.Max() > 0)) { + return absl::StrCat("Potential integer overflow in constraint: ", + ProtobufShortDebugString(ct)); + } return ""; } @@ -382,7 +395,7 @@ std::string ValidateAutomatonConstraint(const CpModelProto& model, } template -std::string ValidateGraphInput(const CpModelProto& model, +std::string ValidateGraphInput(bool is_route, const CpModelProto& model, const GraphProto& graph) { const int size = graph.tails().size(); if (graph.heads().size() != size || graph.literals().size() != size) { @@ -400,6 +413,10 @@ std::string ValidateGraphInput(const CpModelProto& model, "node ", graph.heads(i)); } + if (is_route && graph.tails(i) == 0) { + return absl::StrCat( + "A route constraint cannot have a self-loop on the depot (node 0)"); + } } return ""; @@ -410,10 +427,16 @@ std::string ValidateRoutesConstraint(const CpModelProto& model, int max_node = 0; absl::flat_hash_set nodes; for (const int node : ct.routes().tails()) { + if (node < 0) { + return "All node in a route constraint must be in [0, num_nodes)"; + } nodes.insert(node); max_node = std::max(max_node, node); } for (const int node : ct.routes().heads()) { + if (node < 0) { + return "All node in a route constraint must be in [0, num_nodes)"; + } nodes.insert(node); max_node = std::max(max_node, node); } @@ -422,7 +445,7 @@ std::string ValidateRoutesConstraint(const CpModelProto& model, "All nodes in a route constraint must have incident arcs"); } - return ValidateGraphInput(model, ct.routes()); + return ValidateGraphInput(/*is_route=*/true, model, ct.routes()); } std::string ValidateDomainIsPositive(const CpModelProto& model, int ref, @@ -863,7 +886,8 @@ std::string ValidateCpModel(const CpModelProto& model) { RETURN_IF_NOT_EMPTY(ValidateAutomatonConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kCircuit: - RETURN_IF_NOT_EMPTY(ValidateGraphInput(model, ct.circuit())); + RETURN_IF_NOT_EMPTY( + ValidateGraphInput(/*is_route=*/false, model, ct.circuit())); break; case ConstraintProto::ConstraintCase::kRoutes: RETURN_IF_NOT_EMPTY(ValidateRoutesConstraint(model, ct)); diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 4decfd3b48..121ed32dd1 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -30,9 +30,9 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/random/random.h" #include "absl/strings/str_join.h" -#include "ortools/base/hash.h" #include "ortools/base/integral_types.h" #include "ortools/base/logging.h" #include "ortools/base/map_util.h" @@ -6867,22 +6867,33 @@ bool CpModelPresolver::Presolve() { // // TODO(user): We might want to do that earlier so that our count of variable // usage is not biased by duplicate constraints. - const std::vector duplicates = + const std::vector> duplicates = FindDuplicateConstraints(*context_->working_model); - if (!duplicates.empty()) { - for (const int c : duplicates) { - const int type = - context_->working_model->constraints(c).constraint_case(); - if (type == ConstraintProto::ConstraintCase::kInterval) { - // TODO(user): we could delete duplicate identical interval, but we need - // to make sure reference to them are updated. - continue; - } - - context_->working_model->mutable_constraints(c)->Clear(); - context_->UpdateConstraintVariableUsage(c); - context_->UpdateRuleStats("removed duplicate constraints"); + for (const auto [dup, rep] : duplicates) { + const int type = + context_->working_model->constraints(dup).constraint_case(); + if (type == ConstraintProto::ConstraintCase::kInterval) { + // TODO(user): we could delete duplicate identical interval, but we need + // to make sure reference to them are updated. + continue; } + + if (type == ConstraintProto::kLinear) { + const Domain d1 = ReadDomainFromProto( + context_->working_model->constraints(rep).linear()); + const Domain d2 = ReadDomainFromProto( + context_->working_model->constraints(dup).linear()); + if (d1 != d2) { + context_->UpdateRuleStats("duplicate: merged rhs of linear constraint"); + FillDomainInProto(d1.IntersectionWith(d2), + context_->working_model->mutable_constraints(rep) + ->mutable_linear()); + } + } + + context_->working_model->mutable_constraints(dup)->Clear(); + context_->UpdateConstraintVariableUsage(dup); + context_->UpdateRuleStats("duplicate: removed constraint"); } if (context_->ModelIsUnsat()) { @@ -7133,8 +7144,20 @@ void ApplyVariableMapping(const std::vector& mapping, } } -std::vector FindDuplicateConstraints(const CpModelProto& model_proto) { - std::vector result; +namespace { +ConstraintProto CopyConstraintForDuplicateDetection(const ConstraintProto& ct) { + ConstraintProto copy = ct; + copy.clear_name(); + if (ct.constraint_case() == ConstraintProto::kLinear) { + copy.mutable_linear()->clear_domain(); + } + return copy; +} +} // namespace + +std::vector> FindDuplicateConstraints( + const CpModelProto& model_proto) { + std::vector> result; // We use a map hash: serialized_constraint_proto hash -> constraint index. ConstraintProto copy; @@ -7144,26 +7167,25 @@ std::vector FindDuplicateConstraints(const CpModelProto& model_proto) { const int num_constraints = model_proto.constraints().size(); for (int c = 0; c < num_constraints; ++c) { if (model_proto.constraints(c).constraint_case() == - ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET) { + ConstraintProto::CONSTRAINT_NOT_SET) { continue; } // We ignore names when comparing constraints. // // TODO(user): This is not particularly efficient. - copy = model_proto.constraints(c); - copy.clear_name(); + copy = CopyConstraintForDuplicateDetection(model_proto.constraints(c)); s = copy.SerializeAsString(); - const int64_t hash = std::hash()(s); - const auto insert = equiv_constraints.insert({hash, c}); - if (!insert.second) { + const int64_t hash = absl::Hash()(s); + const auto [it, inserted] = equiv_constraints.insert({hash, c}); + if (!inserted) { // Already present! - const int other_c_with_same_hash = insert.first->second; - copy = model_proto.constraints(other_c_with_same_hash); - copy.clear_name(); + const int other_c_with_same_hash = it->second; + copy = CopyConstraintForDuplicateDetection( + model_proto.constraints(other_c_with_same_hash)); if (s == copy.SerializeAsString()) { - result.push_back(c); + result.push_back({c, other_c_with_same_hash}); } } } diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index f44d376cd3..d66b521ad6 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -263,16 +263,19 @@ void CopyEverythingExceptVariablesAndConstraintsFieldsIntoContext( bool PresolveCpModel(PresolveContext* context, std::vector* postsolve_mapping); -// Returns the index of exact duplicate constraints in the given proto. That -// is, all returned constraints will have an identical constraint before it in -// the model_proto.constraints() list. Empty constraints are ignored. +// Returns the index of duplicate constraints in the given proto in the first +// element of each pair. The second element of each pair is the "representative" +// that is the first constraint in the proto in a set of duplicate constraints. +// +// Empty constraints are ignored. We also do a bit more: +// - We ignore names when comparing constraint. +// - For linear constraints, we ignore the domain. This is because we can +// just merge them if the constraints are the same. // // Visible here for testing. This is meant to be called at the end of the // presolve where constraints have been canonicalized. -// -// TODO(user): Ignore names? canonicalize constraint further by sorting -// enforcement literal list for instance... -std::vector FindDuplicateConstraints(const CpModelProto& model_proto); +std::vector> FindDuplicateConstraints( + const CpModelProto& model_proto); } // namespace sat } // namespace operations_research