From c14e54cf82442a6ca327b99a070f3a5d50b8ef9d Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Fri, 20 Jun 2025 15:11:37 +0200 Subject: [PATCH] [CP-SAT] print a solution after a SIGTERM; improve precedences --- ortools/sat/2d_distances_propagator.cc | 35 ++- ortools/sat/BUILD.bazel | 6 +- ortools/sat/cp_model_mapping.h | 1 - ortools/sat/cp_model_solver.cc | 45 +-- ortools/sat/cp_model_solver_helpers.cc | 94 +++++- ortools/sat/cp_model_solver_helpers.h | 8 + ortools/sat/integer_base.cc | 37 +-- ortools/sat/integer_base.h | 10 +- ortools/sat/precedences.cc | 402 ++++++++++++------------- ortools/sat/precedences.h | 269 +++++++++++------ ortools/sat/precedences_test.cc | 17 +- ortools/sat/sat_parameters.proto | 9 +- ortools/sat/sat_runner.cc | 180 ++++++++--- ortools/sat/synchronization.cc | 132 ++++++-- ortools/sat/synchronization.h | 99 +++++- ortools/sat/synchronization_test.cc | 20 +- ortools/util/sigint.cc | 36 ++- ortools/util/sigint.h | 20 +- ortools/util/sorted_interval_list.h | 4 +- 19 files changed, 934 insertions(+), 490 deletions(-) diff --git a/ortools/sat/2d_distances_propagator.cc b/ortools/sat/2d_distances_propagator.cc index 2053e29581..3d455420a5 100644 --- a/ortools/sat/2d_distances_propagator.cc +++ b/ortools/sat/2d_distances_propagator.cc @@ -13,6 +13,7 @@ #include "ortools/sat/2d_distances_propagator.h" +#include #include #include #include @@ -69,10 +70,12 @@ void Precedences2DPropagator::UpdateVarLookups() { void Precedences2DPropagator::CollectNewPairsOfBoxesWithNonTrivialDistance() { const absl::Span exprs = non_trivial_bounds_->GetLinear2WithPotentialNonTrivalBounds(); - if (exprs.size() != num_known_linear2_) { - VLOG(2) << "CollectPairsOfBoxesWithNonTrivialDistance called, num_exprs: " - << exprs.size(); + if (exprs.size() == num_known_linear2_) { + return; } + VLOG(2) << "CollectPairsOfBoxesWithNonTrivialDistance called, num_exprs: " + << exprs.size(); + const int previous_num_pairs = non_trivial_pairs_.size(); for (; num_known_linear2_ < exprs.size(); ++num_known_linear2_) { const LinearExpression2& positive_expr = exprs[num_known_linear2_]; LinearExpression2 negated_expr = positive_expr; @@ -111,7 +114,31 @@ void Precedences2DPropagator::CollectNewPairsOfBoxesWithNonTrivialDistance() { } } - gtl::STLSortAndRemoveDuplicates(&non_trivial_pairs_); + // Sort the new pairs. + std::sort(non_trivial_pairs_.begin() + previous_num_pairs, + non_trivial_pairs_.end()); + + // Remove duplicates from new pairs. + non_trivial_pairs_.erase( + std::unique(non_trivial_pairs_.begin() + previous_num_pairs, + non_trivial_pairs_.end()), + non_trivial_pairs_.end()); + + // Merge with the old pairs keeping sorted. + std::inplace_merge(non_trivial_pairs_.begin(), + non_trivial_pairs_.begin() + previous_num_pairs, + non_trivial_pairs_.end()); + + // Remove newly-added duplicates. + non_trivial_pairs_.erase( + std::unique(non_trivial_pairs_.begin(), non_trivial_pairs_.end()), + non_trivial_pairs_.end()); + + // Result should be sorted and without duplicates. + DCHECK(std::is_sorted(non_trivial_pairs_.begin(), non_trivial_pairs_.end())); + DCHECK(std::adjacent_find(non_trivial_pairs_.begin(), + non_trivial_pairs_.end()) == + non_trivial_pairs_.end()); } bool Precedences2DPropagator::Propagate() { diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 771c6c010a..4dc706800d 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -815,7 +815,6 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_utils", - ":integer", ":integer_base", ":linear_constraint", ":model", @@ -2056,6 +2055,7 @@ cc_library( deps = [ ":clause", ":cp_constraints", + ":cp_model_mapping", ":integer", ":integer_base", ":model", @@ -4023,13 +4023,17 @@ cc_binary( "//ortools/base:path", "//ortools/util:file_util", "//ortools/util:logging", + "//ortools/util:sigint", "//ortools/util:sorted_interval_list", + "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/flags:flag", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:flags", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/synchronization", + "@abseil-cpp//absl/types:span", "@protobuf", ], ) diff --git a/ortools/sat/cp_model_mapping.h b/ortools/sat/cp_model_mapping.h index 1a82e4263e..5cf63e3e2f 100644 --- a/ortools/sat/cp_model_mapping.h +++ b/ortools/sat/cp_model_mapping.h @@ -24,7 +24,6 @@ #include "ortools/base/strong_vector.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" -#include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 647d15efb0..e4c60ab3de 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -793,40 +793,6 @@ void LogSubsolverNames(absl::Span> subsolvers, SOLVER_LOG(logger, ""); } -void LogFinalStatistics(SharedClasses* shared) { - if (!shared->logger->LoggingIsEnabled()) return; - - shared->logger->FlushPendingThrottledLogs(/*ignore_rates=*/true); - SOLVER_LOG(shared->logger, ""); - - shared->stat_tables->Display(shared->logger); - shared->response->DisplayImprovementStatistics(); - - std::vector> table; - table.push_back({"Solution repositories", "Added", "Queried", "Synchro"}); - shared->response->SolutionPool().AddTableStats(&table); - table.push_back(shared->ls_hints->TableLineStats()); - if (shared->lp_solutions != nullptr) { - table.push_back(shared->lp_solutions->TableLineStats()); - } - if (shared->incomplete_solutions != nullptr) { - table.push_back(shared->incomplete_solutions->TableLineStats()); - } - SOLVER_LOG(shared->logger, FormatTable(table)); - - if (shared->bounds) { - shared->bounds->LogStatistics(shared->logger); - } - - if (shared->clauses) { - shared->clauses->LogStatistics(shared->logger); - } - - // Extra logging if needed. Note that these are mainly activated on - // --vmodule *some_file*=1 and are here for development. - shared->stats->Log(shared->logger); -} - void LaunchSubsolvers(const SatParameters& params, SharedClasses* shared, std::vector>& subsolvers, absl::Span ignored) { @@ -868,7 +834,7 @@ void LaunchSubsolvers(const SatParameters& params, SharedClasses* shared, for (int i = 0; i < subsolvers.size(); ++i) { subsolvers[i].reset(); } - LogFinalStatistics(shared); + shared->LogFinalStatistics(); } bool VarIsFixed(const CpModelProto& model_proto, int i) { @@ -1124,13 +1090,18 @@ class FullProblemSolver : public SubSolver { shared_->model_proto, shared_->bounds.get(), &local_model_); } + if (shared_->linear2_bounds != nullptr) { + RegisterLinear2BoundsImport(shared_->linear2_bounds.get(), + &local_model_); + } + // Note that this is done after the loading, so we will never export // problem clauses. if (shared_->clauses != nullptr) { const int id = shared_->clauses->RegisterNewId( + local_model_.Name(), /*may_terminate_early=*/stop_at_first_solution_ && - local_model_.GetOrCreate()->has_objective()); - shared_->clauses->SetWorkerNameForId(id, local_model_.Name()); + local_model_.GetOrCreate()->has_objective()); RegisterClausesLevelZeroImport(id, shared_->clauses.get(), &local_model_); diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 083d657587..3e22dabbf5 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -847,6 +847,59 @@ void RegisterVariableBoundsLevelZeroImport( import_level_zero_bounds); } +void RegisterLinear2BoundsImport(SharedLinear2Bounds* shared_linear2_bounds, + Model* model) { + CHECK(shared_linear2_bounds != nullptr); + auto* cp_model_mapping = model->GetOrCreate(); + auto* root_linear2 = model->GetOrCreate(); + auto* sat_solver = model->GetOrCreate(); + const int import_id = + shared_linear2_bounds->RegisterNewImportId(model->Name()); + const auto& import_function = [import_id, shared_linear2_bounds, root_linear2, + cp_model_mapping, sat_solver, model]() { + const auto new_bounds = + shared_linear2_bounds->NewlyUpdatedBounds(import_id); + int num_imported = 0; + for (const auto& [proto_expr, bounds] : new_bounds) { + // Lets create the corresponding LinearExpression2. + LinearExpression2 expr; + for (const int i : {0, 1}) { + expr.vars[i] = cp_model_mapping->Integer(proto_expr.vars[i]); + expr.coeffs[i] = proto_expr.coeffs[i]; + } + const auto [lb, ub] = bounds; + const auto [lb_added, ub_added] = root_linear2->Add(expr, lb, ub); + if (!lb_added && !ub_added) continue; + ++num_imported; + + // TODO(user): Is it a good idea to add the linear constraint ? + // We might have many redundant linear2 relations that don't need + // propagation when we have chains of precedences. The root_linear2 should + // be up-to-date with transitive closure to avoid adding such relations + // (recompute it at level zero before this?). + // + // TODO(user): use IntegerValure directly in + // AddWeightedSumGreaterOrEqual() or use a lower-level API. + const std::vector coeffs = {expr.coeffs[0].value(), + expr.coeffs[1].value()}; + if (lb_added) { + AddWeightedSumGreaterOrEqual({}, absl::MakeSpan(expr.vars, 2), coeffs, + lb.value(), model); + if (sat_solver->ModelIsUnsat()) return false; + } + if (ub_added) { + AddWeightedSumLowerOrEqual({}, absl::MakeSpan(expr.vars, 2), coeffs, + ub.value(), model); + if (sat_solver->ModelIsUnsat()) return false; + } + } + shared_linear2_bounds->NotifyNumImported(import_id, num_imported); + return true; + }; + model->GetOrCreate()->callbacks.push_back( + import_function); +} + // Registers a callback that will report improving objective best bound. // It will be called each time new objective bound are propagated at level zero. void RegisterObjectiveBestBoundExport( @@ -2086,6 +2139,10 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model) bounds->LoadDebugSolution(response->DebugSolution()); } + if (params.share_linear2_bounds()) { + linear2_bounds = std::make_unique(); + } + // Create extra shared classes if needed. Note that while these parameters // are true by default, we disable them if we don't have enough workers for // them in AdaptGlobalParameters(). @@ -2120,7 +2177,7 @@ void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) { local_model->Register(stat_tables); // TODO(user): Use parameters and not the presence/absence of these class - // to decide when to use them. + // to decide when to use them? this is not clear. if (lp_solutions != nullptr) { local_model->Register(lp_solutions.get()); } @@ -2134,6 +2191,9 @@ void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) { if (clauses != nullptr) { local_model->Register(clauses.get()); } + if (linear2_bounds != nullptr) { + local_model->Register(linear2_bounds.get()); + } } bool SharedClasses::SearchIsDone() { @@ -2146,5 +2206,37 @@ bool SharedClasses::SearchIsDone() { return false; } +void SharedClasses::LogFinalStatistics() { + if (!logger->LoggingIsEnabled()) return; + + logger->FlushPendingThrottledLogs(/*ignore_rates=*/true); + SOLVER_LOG(logger, ""); + + stat_tables->Display(logger); + response->DisplayImprovementStatistics(); + + std::vector> table; + table.push_back({"Solution repositories", "Added", "Queried", "Synchro"}); + response->SolutionPool().AddTableStats(&table); + table.push_back(ls_hints->TableLineStats()); + if (lp_solutions != nullptr) { + table.push_back(lp_solutions->TableLineStats()); + } + if (incomplete_solutions != nullptr) { + table.push_back(incomplete_solutions->TableLineStats()); + } + SOLVER_LOG(logger, FormatTable(table)); + + // TODO(user): we can combine the "bounds table" into one for shorter logs. + if (bounds != nullptr) bounds->LogStatistics(logger); + if (linear2_bounds != nullptr) linear2_bounds->LogStatistics(logger); + + if (clauses != nullptr) clauses->LogStatistics(logger); + + // Extra logging if needed. Note that these are mainly activated on + // --vmodule *some_file*=1 and are here for development. + stats->Log(logger); +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index 1f46f77495..af00cb3213 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -60,12 +60,15 @@ struct SharedClasses { std::unique_ptr lp_solutions; std::unique_ptr incomplete_solutions; std::unique_ptr clauses; + std::unique_ptr linear2_bounds; // call local_model->Register() on most of the class here, this allow to // more easily depends on one of the shared class deep within the solver. void RegisterSharedClassesInLocalModel(Model* local_model); bool SearchIsDone(); + + void LogFinalStatistics(); }; // Loads a CpModelProto inside the given model. @@ -119,6 +122,11 @@ int RegisterClausesLevelZeroImport(int id, SharedClausesManager* shared_clauses_manager, Model* model); +// This will register a level zero callback to imports new linear2 from the +// SharedLinear2Bounds. +void RegisterLinear2BoundsImport(SharedLinear2Bounds* shared_linear2_bounds, + Model* model); + void PostsolveResponseWrapper(const SatParameters& params, int num_variable_in_original_model, const CpModelProto& mapping_proto, diff --git a/ortools/sat/integer_base.cc b/ortools/sat/integer_base.cc index 740c02bd0d..29a7d8d186 100644 --- a/ortools/sat/integer_base.cc +++ b/ortools/sat/integer_base.cc @@ -84,22 +84,18 @@ bool LinearExpression2::NegateForCanonicalization() { } bool LinearExpression2::CanonicalizeAndUpdateBounds(IntegerValue& lb, - IntegerValue& ub, - bool allow_negation) { + IntegerValue& ub) { SimpleCanonicalization(); if (coeffs[0] == 0 || coeffs[1] == 0) return false; // abort. - bool negated = false; - if (allow_negation) { - negated = NegateForCanonicalization(); - if (negated) { - // We need to be able to negate without overflow. - CHECK_GE(lb, kMinIntegerValue); - CHECK_LE(ub, kMaxIntegerValue); - std::swap(lb, ub); - lb = -lb; - ub = -ub; - } + const bool negated = NegateForCanonicalization(); + if (negated) { + // We need to be able to negate without overflow. + CHECK_GE(lb, kMinIntegerValue); + CHECK_LE(ub, kMaxIntegerValue); + std::swap(lb, ub); + lb = -lb; + ub = -ub; } // Do gcd division. @@ -144,8 +140,7 @@ std::pair BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb, IntegerValue ub) { - const bool negated = - expr.CanonicalizeAndUpdateBounds(lb, ub, /*allow_negation=*/true); + const bool negated = expr.CanonicalizeAndUpdateBounds(lb, ub); // We only store proper linear2. if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { @@ -184,7 +179,7 @@ BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb, RelationStatus BestBinaryRelationBounds::GetStatus(LinearExpression2 expr, IntegerValue lb, IntegerValue ub) const { - expr.CanonicalizeAndUpdateBounds(lb, ub, /*allow_negation=*/true); + expr.CanonicalizeAndUpdateBounds(lb, ub); if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { return RelationStatus::IS_UNKNOWN; } @@ -245,14 +240,4 @@ BestBinaryRelationBounds::GetSortedNonTrivialBounds() const { return root_relations_sorted; } -void BestBinaryRelationBounds::AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const { - for (const auto& [expr, unused] : best_bounds_) { - if (!var_set[PositiveVariable(expr.vars[0])]) continue; - if (!var_set[PositiveVariable(expr.vars[1])]) continue; - result->push_back(expr); - } -} - } // namespace operations_research::sat diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index 9eb30219cc..ba4f04cdff 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -384,8 +384,7 @@ struct LinearExpression2 { // accordingly. This is the same as SimpleCanonicalization(), DivideByGcd() // and the NegateForCanonicalization() with a proper updates of the bounds. // Returns whether the expression was negated. - bool CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub, - bool allow_negation = false); + bool CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub); // Divides the expression by the gcd of both coefficients, and returns it. // Note that we always return something >= 1 even if both coefficients are @@ -493,7 +492,7 @@ class BestBinaryRelationBounds { IntegerValue GetUpperBound(LinearExpression2 expr) const; // Same as GetUpperBound() but assume the expression is already canonicalized. - // This is slighlty faster. + // This is slightly faster. IntegerValue UpperBoundWhenCanonicalized(LinearExpression2 expr) const; int64_t num_bounds() const { return best_bounds_.size(); } @@ -504,11 +503,6 @@ class BestBinaryRelationBounds { std::vector> GetSortedNonTrivialBounds() const; - // Note that this is non-deterministic and in O(num_relations). - void AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const; - private: // The best bound on the given "canonicalized" expression. absl::flat_hash_map> diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index 017cc45550..5618fb304a 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -54,6 +55,53 @@ namespace operations_research { namespace sat { +LinearExpression2Index Linear2WithPotentialNonTrivalBounds::AddOrGet( + LinearExpression2 original_expr) { + LinearExpression2 expr = original_expr; + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); + DCHECK_NE(expr.coeffs[0], 0); + DCHECK_NE(expr.coeffs[1], 0); + const bool negated = expr.NegateForCanonicalization(); + auto [it, inserted] = expr_to_index_.insert({expr, exprs_.size()}); + if (inserted) { + CHECK_LT(2 * exprs_.size() + 1, + std::numeric_limits::max()); + exprs_.push_back(expr); + } + const LinearExpression2Index result = + negated ? NegationOf(LinearExpression2Index(2 * it->second)) + : LinearExpression2Index(2 * it->second); + + if (!inserted) return result; + + // Update our special coeff=1 lookup table. + if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) { + // +2 to handle possible negation. + const int new_size = + std::max(expr.vars[0].value(), expr.vars[1].value()) + 2; + if (new_size > coeff_one_var_lookup_.size()) { + coeff_one_var_lookup_.resize(new_size); + } + LinearExpression2 neg_expr = original_expr; + neg_expr.Negate(); + coeff_one_var_lookup_[original_expr.vars[0]].push_back(result); + coeff_one_var_lookup_[original_expr.vars[1]].push_back(result); + coeff_one_var_lookup_[neg_expr.vars[1]].push_back(NegationOf(result)); + coeff_one_var_lookup_[neg_expr.vars[0]].push_back(NegationOf(result)); + } + + // Update our per-variable and per-pair lookup tables. + IntegerVariable var1 = PositiveVariable(expr.vars[0]); + IntegerVariable var2 = PositiveVariable(expr.vars[1]); + if (var1 > var2) std::swap(var1, var2); + var_pair_to_bounds_[{var1, var2}].push_back(result); + var_to_bounds_[var1].push_back(result); + var_to_bounds_[var2].push_back(result); + + return result; +} + void Linear2Watcher::NotifyBoundChanged(LinearExpression2 expr) { DCHECK(expr.IsCanonicalized()); DCHECK_EQ(expr.DivideByGcd(), 1); @@ -75,115 +123,51 @@ int64_t Linear2Watcher::VarTimestamp(IntegerVariable var) { return var < var_timestamp_.size() ? var_timestamp_[var] : 0; } -std::pair RootLevelLinear2Bounds::Add(LinearExpression2 expr, - IntegerValue lb, - IntegerValue ub) { - using AddResult = BestBinaryRelationBounds::AddResult; - const IntegerValue zero_level_lb = integer_trail_->LevelZeroLowerBound(expr); +bool RootLevelLinear2Bounds::AddUpperBound(LinearExpression2Index index, + IntegerValue ub) { + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); const IntegerValue zero_level_ub = integer_trail_->LevelZeroUpperBound(expr); - if (lb <= zero_level_lb && ub >= zero_level_ub) { - return {false, false}; - } - // Don't store one of the bounds if it is trivial. - if (lb <= zero_level_lb) { - lb = kMinIntegerValue; - } if (ub >= zero_level_ub) { - ub = kMaxIntegerValue; + return false; } - expr.CanonicalizeAndUpdateBounds(lb, ub); - const auto [status_lb, status_ub] = root_level_relations_.Add(expr, lb, ub); + if (best_upper_bounds_.size() <= index) { + best_upper_bounds_.resize(index.value() + 1, kMaxIntegerValue); + } + if (ub >= best_upper_bounds_[index]) { + return false; + } + best_upper_bounds_[index] = ub; - const bool lb_restricted = - status_lb == AddResult::ADDED || status_lb == AddResult::UPDATED; - const bool ub_restricted = - status_ub == AddResult::ADDED || status_ub == AddResult::UPDATED; - if (!lb_restricted && !ub_restricted) return {false, false}; - - non_trivial_bounds_->AddOrGet(expr); ++num_updates_; linear2_watcher_->NotifyBoundChanged(expr); - // Update our special coeff=1 lookup table. - if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) { - // +2 to handle possible negation. - const int new_size = - std::max(expr.vars[0].value(), expr.vars[1].value()) + 2; - if (new_size > coeff_one_var_lookup_.size()) { - coeff_one_var_lookup_.resize(new_size); - } - if (status_lb == AddResult::ADDED) { - // First time added to root_level_relations_. - coeff_one_var_lookup_[NegationOf(expr.vars[0])].push_back( - NegationOf(expr.vars[1])); - coeff_one_var_lookup_[NegationOf(expr.vars[1])].push_back( - NegationOf(expr.vars[0])); - } - if (status_ub == AddResult::ADDED) { - coeff_one_var_lookup_[expr.vars[0]].push_back(expr.vars[1]); - coeff_one_var_lookup_[expr.vars[1]].push_back(expr.vars[0]); + // Share. + // + // TODO(user): It seems we could change the canonicalization to only use + // positive variable? that would simplify a bit the code here and not make it + // worse elsewhere? + if (shared_linear2_bounds_ != nullptr) { + const IntegerValue lb = -LevelZeroUpperBound(NegationOf(index)); + const int proto_var0 = + cp_model_mapping_->GetProtoVariableFromIntegerVariable( + PositiveVariable(expr.vars[0])); + const int proto_var1 = + cp_model_mapping_->GetProtoVariableFromIntegerVariable( + PositiveVariable(expr.vars[1])); + if (proto_var0 >= 0 && proto_var1 >= 0) { + // This is also a relation between cp_model proto variable. Share it! + // Note that since expr is canonicalized, this one should too. + SharedLinear2Bounds::Key key; + key.vars[0] = proto_var0; + key.coeffs[0] = + VariableIsPositive(expr.vars[0]) ? expr.coeffs[0] : -expr.coeffs[0]; + key.vars[1] = proto_var1; + key.coeffs[1] = + VariableIsPositive(expr.vars[1]) ? expr.coeffs[1] : -expr.coeffs[1]; + shared_linear2_bounds_->Add(shared_linear2_bounds_id_, key, lb, ub); } } - - // Update our per-variable and per-pair lookup tables. - IntegerVariable var1 = PositiveVariable(expr.vars[0]); - IntegerVariable var2 = PositiveVariable(expr.vars[1]); - if (var1 > var2) std::swap(var1, var2); - - auto [it_var, inserted] = var_to_bounds_vector_index_.insert({expr, {0, 0}}); - for (const IntegerVariable var : {var1, var2}) { - auto& var_bounds = var_to_bounds_[var]; - if (inserted) { - if (var == var1) { - it_var->second.first = var_bounds.size(); - } else { - it_var->second.second = var_bounds.size(); - } - var_bounds.push_back({expr, lb, ub}); - } else { - const int index = - (var == var1) ? it_var->second.first : it_var->second.second; - DCHECK_LT(index, var_bounds.size()); - std::tuple& var_bound = - var_bounds[index]; - if (status_lb == AddResult::ADDED || status_lb == AddResult::UPDATED) { - std::get<1>(var_bound) = lb; - } - if (status_ub == AddResult::ADDED || status_ub == AddResult::UPDATED) { - std::get<2>(var_bound) = ub; - } - } - } - - auto [it_pair, pair_inserted] = - var_pair_to_bounds_vector_index_.insert({expr, 0}); - DCHECK_EQ(inserted, pair_inserted); - auto& pair_bounds = var_pair_to_bounds_[{var1, var2}]; - if (pair_inserted) { - it_pair->second = pair_bounds.size(); - pair_bounds.push_back({expr, lb, ub}); - } else { - const int index = it_pair->second; - DCHECK_LT(index, pair_bounds.size()); - std::tuple& pair_bound = - pair_bounds[index]; - if (status_lb == AddResult::ADDED || status_lb == AddResult::UPDATED) { - std::get<1>(pair_bound) = lb; - } - if (status_ub == AddResult::ADDED || status_ub == AddResult::UPDATED) { - std::get<2>(pair_bound) = ub; - } - } - - return {lb_restricted, ub_restricted}; -} - -IntegerValue RootLevelLinear2Bounds::LevelZeroUpperBound( - LinearExpression2 expr) const { - // TODO(user): Remove the expression from the root_level_relations_ if the - // zero-level bound got more restrictive. - return std::min(integer_trail_->LevelZeroUpperBound(expr), - root_level_relations_.GetUpperBound(expr)); + return true; } RootLevelLinear2Bounds::~RootLevelLinear2Bounds() { @@ -209,38 +193,38 @@ RelationStatus RootLevelLinear2Bounds::GetLevelZeroStatus( } IntegerValue RootLevelLinear2Bounds::GetUpperBoundNoTrail( - LinearExpression2 expr) const { - DCHECK_EQ(expr.DivideByGcd(), 1); - DCHECK(expr.IsCanonicalized()); - return root_level_relations_.UpperBoundWhenCanonicalized(expr); + LinearExpression2Index index) const { + if (best_upper_bounds_.size() <= index) { + return kMaxIntegerValue; + } + return best_upper_bounds_[index]; } std::vector> RootLevelLinear2Bounds::GetSortedNonTrivialUpperBounds() const { - std::vector> result = - root_level_relations_.GetSortedNonTrivialUpperBounds(); - int new_size = 0; - for (int i = 0; i < result.size(); ++i) { - const auto& [expr, ub] = result[i]; + std::vector> result; + for (LinearExpression2Index index = LinearExpression2Index{0}; + index < best_upper_bounds_.size(); ++index) { + const IntegerValue ub = best_upper_bounds_[index]; + if (ub == kMaxIntegerValue) continue; + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); if (ub < integer_trail_->LevelZeroUpperBound(expr)) { - result[new_size] = {expr, ub}; - ++new_size; + result.push_back({expr, ub}); } } - result.resize(new_size); + std::sort(result.begin(), result.end()); return result; } -// Return a list of (lb <= expr <= ub), with expr.vars[0] = var, where at -// least one of the bounds is non-trivial and the potential other non-trivial -// bound is tight. std::vector> RootLevelLinear2Bounds::GetAllBoundsContainingVariable( IntegerVariable var) const { std::vector> result; - auto it = var_to_bounds_.find(PositiveVariable(var)); - if (it == var_to_bounds_.end()) return {}; - for (const auto& [expr, lb, ub] : it->second) { + for (const LinearExpression2Index index : + non_trivial_bounds_->GetAllLinear2ContainingVariable(var)) { + const IntegerValue lb = -GetUpperBoundNoTrail(NegationOf(index)); + const IntegerValue ub = GetUpperBoundNoTrail(index); + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); const IntegerValue trail_lb = integer_trail_->LevelZeroLowerBound(expr); const IntegerValue trail_ub = integer_trail_->LevelZeroUpperBound(expr); if (lb <= trail_lb && ub >= trail_ub) continue; @@ -271,12 +255,11 @@ std::vector> RootLevelLinear2Bounds::GetAllBoundsContainingVariables( IntegerVariable var1, IntegerVariable var2) const { std::vector> result; - std::pair key = {PositiveVariable(var1), - PositiveVariable(var2)}; - if (key.first > key.second) std::swap(key.first, key.second); - auto it = var_pair_to_bounds_.find(key); - if (it == var_pair_to_bounds_.end()) return {}; - for (const auto& [expr, lb, ub] : it->second) { + for (const LinearExpression2Index index : + non_trivial_bounds_->GetAllLinear2ContainingVariables(var1, var2)) { + const IntegerValue lb = -GetUpperBoundNoTrail(NegationOf(index)); + const IntegerValue ub = GetUpperBoundNoTrail(index); + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); const IntegerValue trail_lb = integer_trail_->LevelZeroLowerBound(expr); const IntegerValue trail_ub = integer_trail_->LevelZeroUpperBound(expr); if (lb <= trail_lb && ub >= trail_ub) continue; @@ -304,10 +287,25 @@ RootLevelLinear2Bounds::GetAllBoundsContainingVariables( return result; } -void RootLevelLinear2Bounds::AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const { - root_level_relations_.AppendAllExpressionContaining(var_set, result); +std::vector +RootLevelLinear2Bounds::GetVariablesInSimpleRelation( + IntegerVariable var) const { + std::vector result; + for (const LinearExpression2Index index : + non_trivial_bounds_->GetAllLinear2ContainingVariableWithCoeffOne(var)) { + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); + const IntegerVariable other = + (expr.vars[0] == var ? expr.vars[1] : expr.vars[0]); + DCHECK_EQ(expr.coeffs[0], 1); + DCHECK_EQ(expr.coeffs[1], 1); + DCHECK((expr.vars[0] == var && expr.vars[1] == other) || + (expr.vars[0] == other && expr.vars[1] == var)); + if (GetUpperBoundNoTrail(index) < + integer_trail_->LevelZeroUpperBound(expr)) { + result.push_back(other); + } + } + return result; } EnforcedLinear2Bounds::~EnforcedLinear2Bounds() { @@ -319,13 +317,8 @@ EnforcedLinear2Bounds::~EnforcedLinear2Bounds() { } void EnforcedLinear2Bounds::PushConditionalRelation( - absl::Span enforcements, LinearExpression2 expr, + absl::Span enforcements, LinearExpression2Index index, IntegerValue rhs) { - expr.SimpleCanonicalization(); - if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { - return; - } - // This must be currently true. if (DEBUG_MODE) { for (const Literal l : enforcements) { @@ -334,24 +327,25 @@ void EnforcedLinear2Bounds::PushConditionalRelation( } if (enforcements.empty() || trail_->CurrentDecisionLevel() == 0) { - root_level_bounds_->AddUpperBound(expr, rhs); + root_level_bounds_->AddUpperBound(index, rhs); return; } - const IntegerValue gcd = expr.DivideByGcd(); - rhs = FloorRatio(rhs, gcd); - - if (rhs >= root_level_bounds_->LevelZeroUpperBound(expr)) return; + if (rhs >= root_level_bounds_->LevelZeroUpperBound(index)) return; + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); linear2_watcher_->NotifyBoundChanged(expr); ++num_conditional_relation_updates_; const int new_index = conditional_stack_.size(); - const auto [it, inserted] = conditional_relations_.insert({expr, new_index}); - if (inserted) { - non_trivial_bounds_->AddOrGet(expr); + if (conditional_relations_.size() <= index) { + conditional_relations_.resize(index.value() + 1, -1); + } + if (conditional_relations_[index] == -1) { + conditional_relations_[index] = new_index; CreateLevelEntryIfNeeded(); - conditional_stack_.emplace_back(/*prev_entry=*/-1, rhs, expr, enforcements); + conditional_stack_.emplace_back(/*prev_entry=*/-1, rhs, index, + enforcements); if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) { const int new_size = @@ -363,13 +357,13 @@ void EnforcedLinear2Bounds::PushConditionalRelation( conditional_var_lookup_[expr.vars[1]].push_back(expr.vars[0]); } } else { - const int prev_entry = it->second; + const int prev_entry = conditional_relations_[index]; if (rhs >= conditional_stack_[prev_entry].rhs) return; // Update. - it->second = new_index; + conditional_relations_[index] = new_index; CreateLevelEntryIfNeeded(); - conditional_stack_.emplace_back(prev_entry, rhs, expr, enforcements); + conditional_stack_.emplace_back(prev_entry, rhs, index, enforcements); } } @@ -392,15 +386,15 @@ void EnforcedLinear2Bounds::SetLevel(int level) { if (back.prev_entry != -1) { conditional_relations_[back.key] = back.prev_entry; } else { - conditional_relations_.erase(back.key); + conditional_relations_[back.key] = -1; + const LinearExpression2 expr = + non_trivial_bounds_->GetExpression(back.key); - if (back.key.coeffs[0] == 1 && back.key.coeffs[1] == 1) { - DCHECK_EQ(conditional_var_lookup_[back.key.vars[0]].back(), - back.key.vars[1]); - DCHECK_EQ(conditional_var_lookup_[back.key.vars[1]].back(), - back.key.vars[0]); - conditional_var_lookup_[back.key.vars[0]].pop_back(); - conditional_var_lookup_[back.key.vars[1]].pop_back(); + if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) { + DCHECK_EQ(conditional_var_lookup_[expr.vars[0]].back(), expr.vars[1]); + DCHECK_EQ(conditional_var_lookup_[expr.vars[1]].back(), expr.vars[0]); + conditional_var_lookup_[expr.vars[0]].pop_back(); + conditional_var_lookup_[expr.vars[1]].pop_back(); } } conditional_stack_.pop_back(); @@ -410,42 +404,42 @@ void EnforcedLinear2Bounds::SetLevel(int level) { } void EnforcedLinear2Bounds::AddReasonForUpperBoundLowerThan( - LinearExpression2 expr, IntegerValue ub, + LinearExpression2Index index, IntegerValue ub, std::vector* literal_reason, std::vector* /*unused*/) const { - expr.SimpleCanonicalization(); - if (ub >= root_level_bounds_->LevelZeroUpperBound(expr)) return; - const IntegerValue gcd = expr.DivideByGcd(); - const auto it = conditional_relations_.find(expr); - DCHECK(it != conditional_relations_.end()); + if (ub >= root_level_bounds_->LevelZeroUpperBound(index)) return; + DCHECK_LT(index, conditional_relations_.size()); + const int entry_index = conditional_relations_[index]; + DCHECK_NE(entry_index, -1); - const ConditionalEntry& entry = conditional_stack_[it->second]; + const ConditionalEntry& entry = conditional_stack_[entry_index]; if (DEBUG_MODE) { for (const Literal l : entry.enforcements) { CHECK(trail_->Assignment().LiteralIsTrue(l)); } } - DCHECK_LE(CapProdI(gcd, entry.rhs), ub); + DCHECK_LE(entry.rhs, ub); for (const Literal l : entry.enforcements) { literal_reason->push_back(l.Negated()); } } IntegerValue EnforcedLinear2Bounds::GetUpperBoundFromEnforced( - LinearExpression2 expr) const { - DCHECK_EQ(expr.DivideByGcd(), 1); - DCHECK(expr.IsCanonicalized()); - const auto it = conditional_relations_.find(expr); - if (it == conditional_relations_.end()) { + LinearExpression2Index index) const { + if (index >= conditional_relations_.size()) { + return kMaxIntegerValue; + } + const int entry_index = conditional_relations_[index]; + if (entry_index == -1) { return kMaxIntegerValue; } else { - const ConditionalEntry& entry = conditional_stack_[it->second]; + const ConditionalEntry& entry = conditional_stack_[entry_index]; if (DEBUG_MODE) { for (const Literal l : entry.enforcements) { CHECK(trail_->Assignment().LiteralIsTrue(l)); } } - DCHECK_LT(entry.rhs, root_level_bounds_->LevelZeroUpperBound(expr)); + DCHECK_LT(entry.rhs, root_level_bounds_->LevelZeroUpperBound(index)); return entry.rhs; } } @@ -569,7 +563,7 @@ void TransitivePrecedencesEvaluator::Build() { } VLOG(2) << "Full precedences. Work=" << work - << " Relations=" << root_level_bounds_->num_bounds(); + << " Relations=" << root_relations_sorted.size(); } void TransitivePrecedencesEvaluator::ComputeFullPrecedences( @@ -738,16 +732,6 @@ void EnforcedLinear2Bounds::CollectPrecedences( } } -void EnforcedLinear2Bounds::AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const { - for (const auto& entry : conditional_stack_) { - if (!var_set[PositiveVariable(entry.key.vars[0])]) continue; - if (!var_set[PositiveVariable(entry.key.vars[1])]) continue; - result->push_back(entry.key); - } -} - namespace { void AppendLowerBoundReasonIfValid(IntegerVariable var, @@ -1828,6 +1812,7 @@ Linear2BoundsFromLinear3::Linear2BoundsFromLinear3(Model* model) bool Linear2BoundsFromLinear3::AddAffineUpperBound(LinearExpression2 expr, AffineExpression affine_ub) { expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false; // At level zero, just add it to root_level_bounds_. if (trail_->CurrentDecisionLevel() == 0) { @@ -1900,16 +1885,6 @@ void Linear2BoundsFromLinear3::AddReasonForUpperBoundLowerThan( integer_reason->push_back(affine.LowerOrEqual(CapProdI(ub + 1, divisor) - 1)); } -void Linear2BoundsFromLinear3::AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const { - for (const auto& [expr, unused] : best_affine_ub_) { - if (!var_set[PositiveVariable(expr.vars[0])]) continue; - if (!var_set[PositiveVariable(expr.vars[1])]) continue; - result->push_back(expr); - } -} - IntegerValue Linear2Bounds::UpperBound(LinearExpression2 expr) const { expr.SimpleCanonicalization(); if (expr.coeffs[0] == 0) { @@ -1918,8 +1893,11 @@ IntegerValue Linear2Bounds::UpperBound(LinearExpression2 expr) const { DCHECK_NE(expr.coeffs[1], 0); const IntegerValue gcd = expr.DivideByGcd(); IntegerValue ub = integer_trail_->UpperBound(expr); - ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(expr)); - ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr)); + const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr); + if (index != kNoLinearExpression2Index) { + ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(index)); + ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(index)); + } ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr)); return CapProdI(gcd, ub); } @@ -1932,8 +1910,12 @@ IntegerValue Linear2Bounds::NonTrivialUpperBoundForGcd1( } DCHECK_NE(expr.coeffs[1], 0); DCHECK_EQ(1, expr.DivideByGcd()); - IntegerValue ub = root_level_bounds_->GetUpperBoundNoTrail(expr); - ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr)); + IntegerValue ub = kMaxIntegerValue; + const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr); + if (index != kNoLinearExpression2Index) { + ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(index)); + ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(index)); + } ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr)); return ub; } @@ -1942,20 +1924,25 @@ void Linear2Bounds::AddReasonForUpperBoundLowerThan( LinearExpression2 expr, IntegerValue ub, std::vector* literal_reason, std::vector* integer_reason) const { - expr.SimpleCanonicalization(); - const IntegerValue gcd = expr.DivideByGcd(); - ub = FloorRatio(ub, gcd); DCHECK_LE(UpperBound(expr), ub); // Explanation are by order of preference, with no reason needed first. - if (root_level_bounds_->LevelZeroUpperBound(expr) <= ub) { + if (integer_trail_->LevelZeroUpperBound(expr) <= ub) { return; } - + expr.SimpleCanonicalization(); + const IntegerValue gcd = expr.DivideByGcd(); + ub = FloorRatio(ub, gcd); + const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr); // This one is a single literal. - if (enforced_bounds_->GetUpperBoundFromEnforced(expr) <= ub) { - return enforced_bounds_->AddReasonForUpperBoundLowerThan( - expr, ub, literal_reason, integer_reason); + if (index != kNoLinearExpression2Index) { + if (root_level_bounds_->GetUpperBoundNoTrail(index) <= ub) { + return; + } + if (enforced_bounds_->GetUpperBoundFromEnforced(index) <= ub) { + return enforced_bounds_->AddReasonForUpperBoundLowerThan( + index, ub, literal_reason, integer_reason); + } } // This one is a single var upper bound. @@ -1975,16 +1962,5 @@ void Linear2Bounds::AddReasonForUpperBoundLowerThan( integer_reason); } -absl::Span -Linear2Bounds::GetAllExpressionsWithPotentialNonTrivialBounds( - Bitset64::ConstView var_set) const { - tmp_expressions_.clear(); - root_level_bounds_->AppendAllExpressionContaining(var_set, &tmp_expressions_); - enforced_bounds_->AppendAllExpressionContaining(var_set, &tmp_expressions_); - linear3_bounds_->AppendAllExpressionContaining(var_set, &tmp_expressions_); - gtl::STLSortAndRemoveDuplicates(&tmp_expressions_); - return tmp_expressions_; -} - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 392943ce63..586b28dd89 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -14,10 +14,10 @@ #ifndef OR_TOOLS_SAT_PRECEDENCES_H_ #define OR_TOOLS_SAT_PRECEDENCES_H_ +#include #include #include #include -#include #include #include #include @@ -31,6 +31,7 @@ #include "absl/types/span.h" #include "ortools/base/strong_vector.h" #include "ortools/graph/graph.h" +#include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" @@ -70,23 +71,14 @@ class Linear2WithPotentialNonTrivalBounds { // Returns a never-changing index for the given linear expression. // The expression must already be canonicalized and divided by its GCD. - LinearExpression2Index AddOrGet(LinearExpression2 expr) { - DCHECK(expr.IsCanonicalized()); - DCHECK_EQ(expr.DivideByGcd(), 1); - const bool negated = expr.NegateForCanonicalization(); - auto [it, inserted] = expr_to_index_.insert({expr, exprs_.size()}); - if (inserted) { - CHECK_LT(2 * exprs_.size() + 1, - std::numeric_limits::max()); - exprs_.push_back(expr); - } - const LinearExpression2Index positive_index(2 * it->second); - if (negated) { - return NegationOf(positive_index); - } else { - return positive_index; - } - } + LinearExpression2Index AddOrGet(LinearExpression2 expr); + + // Returns a never-changing index for the given linear expression if it is + // potentially non-trivial, otherwise returns kNoLinearExpression2Index. The + // expression must already be canonicalized and divided by its GCD. + LinearExpression2Index GetIndex(LinearExpression2 expr) const; + + LinearExpression2 GetExpression(LinearExpression2Index index) const; // Return all positive linear2 expressions that have a potentially non-trivial // bound. When calling this code it is often a good idea to check both the @@ -97,9 +89,45 @@ class Linear2WithPotentialNonTrivalBounds { return exprs_; } + // Return a list of all potentially non-trivial LinearExpression2Indexes + // containing a given variable. + absl::Span GetAllLinear2ContainingVariable( + IntegerVariable var) const; + + // Return a list of all potentially non-trivial LinearExpression2Indexes + // containing a given pair of variables. + absl::Span GetAllLinear2ContainingVariables( + IntegerVariable var1, IntegerVariable var2) const; + + // For a given variable `var`, return all linear expressions with both + // coefficients 1 that have a potentially non trivial upper bound. For + // convenience it also returns the other variable to cheaply build the + // linear2. Note that using negation one can also recover x + y >= lb and x - + // y <= ub. + absl::Span + GetAllLinear2ContainingVariableWithCoeffOne(IntegerVariable var) const { + if (var >= coeff_one_var_lookup_.size()) return {}; + return coeff_one_var_lookup_[var]; + } + private: - util_intops::StrongVector exprs_; + std::vector exprs_; absl::flat_hash_map expr_to_index_; + + // Lookup table to find all the LinearExpression2 with a given variable and + // having both coefficient 1. + util_intops::StrongVector> + coeff_one_var_lookup_; + + // Map to implement GetAllBoundsContainingVariable(). + absl::flat_hash_map> + var_to_bounds_; + // Map to implement GetAllBoundsContainingVariables(). + absl::flat_hash_map, + absl::InlinedVector> + var_pair_to_bounds_; }; // Simple "watcher" class that will be notified if a linear2 bound changed. It @@ -138,7 +166,13 @@ class RootLevelLinear2Bounds { linear2_watcher_(model->GetOrCreate()), shared_stats_(model->GetOrCreate()), non_trivial_bounds_( - model->GetOrCreate()) {} + model->GetOrCreate()), + cp_model_mapping_(model->GetOrCreate()), + shared_linear2_bounds_(model->Mutable()), + shared_linear2_bounds_id_( + shared_linear2_bounds_ == nullptr + ? 0 + : shared_linear2_bounds_->RegisterNewId(model->Name())) {} ~RootLevelLinear2Bounds(); @@ -147,16 +181,49 @@ class RootLevelLinear2Bounds { // Returns a pair saying whether the lower/upper bounds for this expr became // more restricted than what was currently stored. std::pair Add(LinearExpression2 expr, IntegerValue lb, - IntegerValue ub); + IntegerValue ub) { + const bool negated = expr.CanonicalizeAndUpdateBounds(lb, ub); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return {false, false}; + const LinearExpression2Index index = non_trivial_bounds_->AddOrGet(expr); + bool ub_changed = AddUpperBound(index, ub); + bool lb_changed = AddUpperBound(NegationOf(index), -lb); + if (negated) { + std::swap(lb_changed, ub_changed); + } + return {lb_changed, ub_changed}; + } + + bool AddUpperBound(LinearExpression2Index index, IntegerValue ub); // Same as above, but only update the upper bound. bool AddUpperBound(LinearExpression2 expr, IntegerValue ub) { - return Add(expr, kMinIntegerValue, ub).second; + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false; + const IntegerValue gcd = expr.DivideByGcd(); + ub = FloorRatio(ub, gcd); + return AddUpperBound(non_trivial_bounds_->AddOrGet(expr), ub); } - IntegerValue LevelZeroUpperBound(LinearExpression2 expr) const; + IntegerValue LevelZeroUpperBound(LinearExpression2 expr) const { + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { + return integer_trail_->LevelZeroUpperBound(expr); + } + const IntegerValue gcd = expr.DivideByGcd(); + const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr); + if (index == kNoLinearExpression2Index) { + return integer_trail_->LevelZeroUpperBound(expr); + } + return CapProdI(gcd, LevelZeroUpperBound(index)); + } - int64_t num_bounds() const { return root_level_relations_.num_bounds(); } + IntegerValue LevelZeroUpperBound(LinearExpression2Index index) const { + const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index); + // TODO(user): Remove the expression from the root_level_relations_ if + // the zero-level bound got more restrictive. + return std::min(integer_trail_->LevelZeroUpperBound(expr), + GetUpperBoundNoTrail(index)); + } // Return a list of (expr <= ub) sorted by expr. They are guaranteed to be // better than the trivial upper bound. @@ -183,11 +250,8 @@ class RootLevelLinear2Bounds { // For a given variable `var`, return all variables `other` so that // LinearExpression2(var, other, 1, 1) has a non trivial upper bound. // Note that using negation one can also recover x + y >= lb and x - y <= ub. - absl::Span GetVariablesInSimpleRelation( - IntegerVariable var) const { - if (var >= coeff_one_var_lookup_.size()) return {}; - return coeff_one_var_lookup_[var]; - } + std::vector GetVariablesInSimpleRelation( + IntegerVariable var) const; RelationStatus GetLevelZeroStatus(LinearExpression2 expr, IntegerValue lb, IntegerValue ub) const; @@ -197,47 +261,21 @@ class RootLevelLinear2Bounds { // behavior from LevelZeroUpperBound() that would return the implied // zero-level bound from the trail for trivial ones. `expr` must be // canonicalized and gcd-reduced. - IntegerValue GetUpperBoundNoTrail(LinearExpression2 expr) const; - - void AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const; + IntegerValue GetUpperBoundNoTrail(LinearExpression2Index index) const; private: IntegerTrail* integer_trail_; Linear2Watcher* linear2_watcher_; SharedStatistics* shared_stats_; Linear2WithPotentialNonTrivalBounds* non_trivial_bounds_; + CpModelMapping* cp_model_mapping_; + SharedLinear2Bounds* shared_linear2_bounds_; // Might be nullptr. - // Lookup table to find all the LinearExpression2 with a given variable and - // having both coefficient 1. - util_intops::StrongVector> - coeff_one_var_lookup_; + const int shared_linear2_bounds_id_; - // TODO(user): use data structures that consume less memory. A single - // std::vector and hash maps having the index as value - // could be enough. - absl::flat_hash_map< - IntegerVariable, - absl::InlinedVector< - std::tuple, 2>> - var_to_bounds_; - // Map to implement GetAllBoundsContainingVariables(). - absl::flat_hash_map< - std::pair, - absl::InlinedVector< - std::tuple, 1>> - var_pair_to_bounds_; - // Data structure to quickly update var_to_bounds_. Return the index where - // this linear expression appear in the vector for the first and second - // variable. - absl::flat_hash_map> - var_to_bounds_vector_index_; - absl::flat_hash_map var_pair_to_bounds_vector_index_; + util_intops::StrongVector + best_upper_bounds_; - // TODO(user): Also push them to a global shared repository after - // remapping IntegerVariable to proto indices. - BestBinaryRelationBounds root_level_relations_; int64_t num_updates_ = 0; }; @@ -338,7 +376,17 @@ class EnforcedLinear2Bounds : public ReversibleInterface { // If expr is not a proper linear2 expression (e.g. 0*x + y, y + y, y - y) it // will be ignored. void PushConditionalRelation(absl::Span enforcements, - LinearExpression2 expr, IntegerValue rhs); + LinearExpression2Index index, IntegerValue rhs); + + void PushConditionalRelation(absl::Span enforcements, + LinearExpression2 expr, IntegerValue rhs) { + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return; + const IntegerValue gcd = expr.DivideByGcd(); + rhs = FloorRatio(rhs, gcd); + return PushConditionalRelation(enforcements, + non_trivial_bounds_->AddOrGet(expr), rhs); + } // Called each time we change decision level. void SetLevel(int level) final; @@ -365,18 +413,13 @@ class EnforcedLinear2Bounds : public ReversibleInterface { // Low-level function that returns the upper bound if there is some enforced // relations only. Otherwise always returns kMaxIntegerValue. // `expr` must be canonicalized and gcd-reduced. - IntegerValue GetUpperBoundFromEnforced(LinearExpression2 expr) const; + IntegerValue GetUpperBoundFromEnforced(LinearExpression2Index index) const; void AddReasonForUpperBoundLowerThan( - LinearExpression2 expr, IntegerValue ub, + LinearExpression2Index index, IntegerValue ub, std::vector* literal_reason, std::vector* integer_reason) const; - // Note: might contain duplicate expressions. - void AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const; - private: void CreateLevelEntryIfNeeded(); @@ -395,13 +438,13 @@ class EnforcedLinear2Bounds : public ReversibleInterface { // TODO(user): this kind of reversible hash_map is already implemented in // other part of the code. Consolidate. struct ConditionalEntry { - ConditionalEntry(int p, IntegerValue r, LinearExpression2 k, + ConditionalEntry(int p, IntegerValue r, LinearExpression2Index k, absl::Span e) : prev_entry(p), rhs(r), key(k), enforcements(e.begin(), e.end()) {} int prev_entry; IntegerValue rhs; - LinearExpression2 key; + LinearExpression2Index key; absl::InlinedVector enforcements; }; std::vector conditional_stack_; @@ -409,7 +452,7 @@ class EnforcedLinear2Bounds : public ReversibleInterface { // This is always stored in the form (expr <= rhs). // The conditional relations contains indices in the conditional_stack_. - absl::flat_hash_map conditional_relations_; + util_intops::StrongVector conditional_relations_; // Store for each variable x, the variables y that appears alongside it in // lit => x + y <= ub. Note that conditional_var_lookup_ is updated on @@ -510,11 +553,6 @@ class Linear2BoundsFromLinear3 { // will replace it and returns true, otherwise it returns false. bool AddAffineUpperBound(LinearExpression2 expr, AffineExpression affine_ub); - // Warning, the order will not be deterministic. - void AppendAllExpressionContaining( - Bitset64::ConstView var_set, - std::vector* result) const; - // Most users should just use Linear2Bounds::UpperBound() instead. // // Returns the upper bound only if there is some relations coming from a @@ -601,7 +639,9 @@ class Linear2Bounds { : integer_trail_(model->GetOrCreate()), root_level_bounds_(model->GetOrCreate()), enforced_bounds_(model->GetOrCreate()), - linear3_bounds_(model->GetOrCreate()) {} + linear3_bounds_(model->GetOrCreate()), + non_trivial_bounds_( + model->GetOrCreate()) {} // Returns the best known upper-bound of the given LinearExpression2 at the // current decision level. If its explanation is needed, it can be queried @@ -616,31 +656,12 @@ class Linear2Bounds { // don't want the trivial bounds. IntegerValue NonTrivialUpperBoundForGcd1(LinearExpression2 expr) const; - // Returns all known expressions with potentially non-trivial bounds that - // involves two variable whose positive version is marked in 'vars'. - absl::Span - GetAllExpressionsWithPotentialNonTrivialBounds( - Bitset64::ConstView var_set) const; - - // Returns a temporary bitset, cleared, and resized for all existing - // variables. - // - // If we have many class calling - // GetAllExpressionsWithPotentialNonTrivialBounds() it is important that not - // all of them have a O(num_variables) vector when the same one can be used. - SparseBitset* GetTemporyClearedAndResizedBitset() { - tmp_bitset_.ClearAndResize(integer_trail_->NumIntegerVariables()); - return &tmp_bitset_; - } - private: IntegerTrail* integer_trail_; RootLevelLinear2Bounds* root_level_bounds_; EnforcedLinear2Bounds* enforced_bounds_; Linear2BoundsFromLinear3* linear3_bounds_; - - mutable std::vector tmp_expressions_; - SparseBitset tmp_bitset_; + Linear2WithPotentialNonTrivalBounds* non_trivial_bounds_; }; // Detects if at least one of a subset of linear of size 2 or 1, touching the @@ -1000,6 +1021,58 @@ inline std::function ConditionalLowerOrEqualWithOffset( }; } +inline LinearExpression2Index Linear2WithPotentialNonTrivalBounds::GetIndex( + LinearExpression2 expr) const { + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); + const bool negated = expr.NegateForCanonicalization(); + auto it = expr_to_index_.find(expr); + if (it == expr_to_index_.end()) return kNoLinearExpression2Index; + + const LinearExpression2Index positive_index(2 * it->second); + if (negated) { + return NegationOf(positive_index); + } else { + return positive_index; + } +} + +inline LinearExpression2 Linear2WithPotentialNonTrivalBounds::GetExpression( + LinearExpression2Index index) const { + DCHECK_NE(index, kNoLinearExpression2Index); + const int lookup_index = index.value() / 2; + DCHECK_LT(lookup_index, exprs_.size()); + if (Linear2IsPositive(index)) { + return exprs_[lookup_index]; + } else { + LinearExpression2 result = exprs_[lookup_index]; + result.Negate(); + return result; + } +} + +inline absl::Span +Linear2WithPotentialNonTrivalBounds::GetAllLinear2ContainingVariable( + IntegerVariable var) const { + const IntegerVariable positive_var = PositiveVariable(var); + auto it = var_to_bounds_.find(positive_var); + if (it == var_to_bounds_.end()) return {}; + return it->second; +} + +inline absl::Span +Linear2WithPotentialNonTrivalBounds::GetAllLinear2ContainingVariables( + IntegerVariable var1, IntegerVariable var2) const { + IntegerVariable positive_var1 = PositiveVariable(var1); + IntegerVariable positive_var2 = PositiveVariable(var2); + if (positive_var1 > positive_var2) { + std::swap(positive_var1, positive_var2); + } + auto it = var_pair_to_bounds_.find({positive_var1, positive_var2}); + if (it == var_pair_to_bounds_.end()) return {}; + return it->second; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/precedences_test.cc b/ortools/sat/precedences_test.cc index 0f911b9144..159469f659 100644 --- a/ortools/sat/precedences_test.cc +++ b/ortools/sat/precedences_test.cc @@ -190,6 +190,8 @@ TEST(EnforcedLinear2BoundsTest, ConditionalRelations) { auto* lin2_bounds = model.GetOrCreate(); auto* integer_trail = model.GetOrCreate(); auto* precedences = model.GetOrCreate(); + auto* non_trivial_bounds = + model.GetOrCreate(); const std::vector vars = AddVariables(integer_trail); const Literal l(model.Add(NewBooleanVariable()), true); @@ -200,26 +202,25 @@ TEST(EnforcedLinear2BoundsTest, ConditionalRelations) { precedences->PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 15); precedences->PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 20); + LinearExpression2 expr_a_plus_b = + LinearExpression2::Difference(a, NegationOf(b)); + expr_a_plus_b.SimpleCanonicalization(); // We only keep the best one. - EXPECT_EQ( - lin2_bounds->UpperBound(LinearExpression2::Difference(a, NegationOf(b))), - 15); + EXPECT_EQ(lin2_bounds->UpperBound(expr_a_plus_b), 15); std::vector literal_reason; std::vector integer_reason; precedences->AddReasonForUpperBoundLowerThan( - LinearExpression2::Difference(a, NegationOf(b)), 15, &literal_reason, + non_trivial_bounds->AddOrGet(expr_a_plus_b), 15, &literal_reason, &integer_reason); EXPECT_THAT(literal_reason, ElementsAre(l.Negated())); // Backtrack works. EXPECT_TRUE(sat_solver->ResetToLevelZero()); - EXPECT_EQ( - lin2_bounds->UpperBound(LinearExpression2::Difference(a, NegationOf(b))), - 200); + EXPECT_EQ(lin2_bounds->UpperBound(expr_a_plus_b), 200); literal_reason.clear(); integer_reason.clear(); precedences->AddReasonForUpperBoundLowerThan( - LinearExpression2::Difference(a, NegationOf(b)), kMaxIntegerValue, + non_trivial_bounds->AddOrGet(expr_a_plus_b), kMaxIntegerValue, &literal_reason, &integer_reason); EXPECT_THAT(literal_reason, IsEmpty()); } diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index b013e7d314..60901fc1c0 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -24,7 +24,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 326 +// NEXT TAG: 327 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -703,6 +703,13 @@ message SatParameters { // Allows sharing of the bounds of modified variables at level 0. optional bool share_level_zero_bounds = 114 [default = true]; + // Allows sharing of the bounds on linear2 discovered at level 0. This is + // mainly interesting on scheduling type of problems when we branch on + // precedences. + // + // Warning: This currently non-deterministic. + optional bool share_linear2_bounds = 326 [default = false]; + // Allows sharing of new learned binary clause between workers. optional bool share_binary_clauses = 203 [default = true]; diff --git a/ortools/sat/sat_runner.cc b/ortools/sat/sat_runner.cc index c31a0e2b27..c1dceb038b 100644 --- a/ortools/sat/sat_runner.cc +++ b/ortools/sat/sat_runner.cc @@ -16,9 +16,11 @@ #include #include #include +#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/flags/flag.h" #include "absl/flags/parse.h" #include "absl/flags/usage.h" @@ -30,6 +32,8 @@ #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" #include "ortools/base/helpers.h" @@ -45,6 +49,7 @@ #include "ortools/sat/synchronization.h" #include "ortools/util/file_util.h" #include "ortools/util/logging.h" +#include "ortools/util/sigint.h" #include "ortools/util/sorted_interval_list.h" ABSL_FLAG( @@ -102,8 +107,69 @@ std::string ExtractName(absl::string_view full_filename) { return filename; } -void LogInPbCompetitionFormat(int num_variables, bool has_objective, - Model* model, SatParameters* parameters) { +class LastSolutionPrinter { + public: + // Note that is prints the solution in the PB competition format. + void MaybePrintLastSolution() { + absl::MutexLock lock(&mutex_); + if (last_solution_printed_) return; + last_solution_printed_ = true; + + if (last_solution_.empty()) { + std::cout << "s UNKNOWN" << std::endl; + } else { + std::cout << "s SATISFIABLE" << std::endl; + std::string line; + for (int i = 0; i < num_variables_; ++i) { + if (last_solution_[i]) { + absl::StrAppend(&line, "x", i + 1, " "); + } else { + absl::StrAppend(&line, "-x", i + 1, " "); + } + if (line.size() >= 75) { + std::cout << "v " << line << std::endl; + line.clear(); + } + } + if (!line.empty()) { + std::cout << "v " << line << std::endl; + } + } + } + + void set_num_variables(int num_variables) { num_variables_ = num_variables; } + + void set_last_solution(absl::Span solution) { + absl::MutexLock lock(&mutex_); + if (last_solution_printed_) return; + last_solution_.assign(solution.begin(), solution.end()); + } + + // Returns false if the solution has already been printed, else mark it as + // printed by caller code. + bool mark_last_solution_printed() { + const absl::MutexLock lock(&mutex_); + if (last_solution_printed_) { + return false; + } + last_solution_printed_ = true; + return true; + } + + private: + int num_variables_ = 0; + std::vector last_solution_ ABSL_GUARDED_BY(mutex_); + bool last_solution_printed_ ABSL_GUARDED_BY(mutex_) = false; + absl::Mutex mutex_; +}; + +void LogInPbCompetitionFormat( + int num_variables, bool has_objective, Model* model, + SatParameters* parameters, + std::shared_ptr last_solution_printer) { + CHECK(last_solution_printer != nullptr); + last_solution_printer->set_num_variables(num_variables); + const auto log_callback = [](const std::string& multi_line_input) { if (multi_line_input.empty()) { std::cout << "c" << std::endl; @@ -118,55 +184,60 @@ void LogInPbCompetitionFormat(int num_variables, bool has_objective, model->GetOrCreate()->AddInfoLoggingCallback(log_callback); parameters->set_log_to_stdout(false); - const auto response_callback = [](const CpSolverResponse& r) { + const auto response_callback = [last_solution_printer]( + const CpSolverResponse& r) { std::cout << "o " << static_cast(r.objective_value()) << std::endl; + last_solution_printer->set_last_solution(r.solution()); }; model->Add(NewFeasibleSolutionObserver(response_callback)); - const auto final_response_callback = [num_variables, - has_objective](CpSolverResponse* r) { - switch (r->status()) { - case CpSolverStatus::OPTIMAL: - if (has_objective) { - std::cout << "s OPTIMUM FOUND " << std::endl; - } else { - std::cout << "s SATISFIABLE" << std::endl; + const auto final_response_callback = + [num_variables, has_objective, + last_solution_printer](CpSolverResponse* r) { + if (!last_solution_printer->mark_last_solution_printed()) return; + + switch (r->status()) { + case CpSolverStatus::OPTIMAL: + if (has_objective) { + std::cout << "s OPTIMUM FOUND " << std::endl; + } else { + std::cout << "s SATISFIABLE" << std::endl; + } + break; + case CpSolverStatus::FEASIBLE: + std::cout << "s SATISFIABLE" << std::endl; + break; + case CpSolverStatus::INFEASIBLE: + std::cout << "s UNSATISFIABLE" << std::endl; + break; + case CpSolverStatus::MODEL_INVALID: + std::cout << "s UNSUPPORTED" << std::endl; + break; + case CpSolverStatus::UNKNOWN: + std::cout << "s UNKNOWN" << std::endl; + break; + default: + break; } - break; - case CpSolverStatus::FEASIBLE: - std::cout << "s SATISFIABLE" << std::endl; - break; - case CpSolverStatus::INFEASIBLE: - std::cout << "s UNSATISFIABLE" << std::endl; - break; - case CpSolverStatus::MODEL_INVALID: - std::cout << "s UNSUPPORTED" << std::endl; - break; - case CpSolverStatus::UNKNOWN: - std::cout << "s UNKNOWN" << std::endl; - break; - default: - break; - } - if (r->status() == CpSolverStatus::OPTIMAL || - r->status() == CpSolverStatus::FEASIBLE) { - std::string line; - for (int i = 0; i < num_variables; ++i) { - if (r->solution(i)) { - absl::StrAppend(&line, "x", i + 1, " "); - } else { - absl::StrAppend(&line, "-x", i + 1, " "); + if (r->status() == CpSolverStatus::OPTIMAL || + r->status() == CpSolverStatus::FEASIBLE) { + std::string line; + for (int i = 0; i < num_variables; ++i) { + if (r->solution(i)) { + absl::StrAppend(&line, "x", i + 1, " "); + } else { + absl::StrAppend(&line, "-x", i + 1, " "); + } + if (line.size() >= 75) { + std::cout << "v " << line << std::endl; + line.clear(); + } + } + if (!line.empty()) { + std::cout << "v " << line << std::endl; + } } - if (line.size() >= 75) { - std::cout << "v " << line << std::endl; - line.clear(); - } - } - if (!line.empty()) { - std::cout << "v " << line << std::endl; - } - } - }; + }; model->GetOrCreate()->AddFinalResponsePostprocessor( final_response_callback); } @@ -186,7 +257,8 @@ void SetInterleavedWorkers(SatParameters* parameters) { bool LoadProblem(const std::string& filename, absl::string_view hint_file, absl::string_view domain_file, CpModelProto* cp_model, - Model* model, SatParameters* parameters) { + Model* model, SatParameters* parameters, + std::shared_ptr last_solution_printer) { if (absl::EndsWith(filename, ".opb") || absl::EndsWith(filename, ".opb.bz2") || absl::EndsWith(filename, ".opb.gz") || absl::EndsWith(filename, ".wbo") || @@ -217,7 +289,7 @@ bool LoadProblem(const std::string& filename, absl::string_view hint_file, const int num_variables = reader.model_is_supported() ? reader.num_variables() : 1; LogInPbCompetitionFormat(num_variables, cp_model->has_objective(), model, - parameters); + parameters, last_solution_printer); } if (absl::GetFlag(FLAGS_force_interleave_search)) { SetInterleavedWorkers(parameters); @@ -310,9 +382,13 @@ int Run() { google::protobuf::Arena arena; CpModelProto* cp_model = google::protobuf::Arena::Create(&arena); + std::shared_ptr last_solution_printer; + if (absl::GetFlag(FLAGS_competition_mode)) { + last_solution_printer = std::make_shared(); + } if (!LoadProblem(absl::GetFlag(FLAGS_input), absl::GetFlag(FLAGS_hint_file), absl::GetFlag(FLAGS_domain_file), cp_model, &model, - ¶meters)) { + ¶meters, last_solution_printer)) { if (!absl::GetFlag(FLAGS_competition_mode)) { LOG(FATAL) << "Cannot load file '" << absl::GetFlag(FLAGS_input) << "'."; } @@ -329,6 +405,14 @@ int Run() { FingerprintRepeatedField(r.solution(), kDefaultFingerprintSeed)); })); } + + if (absl::GetFlag(FLAGS_competition_mode)) { + model.GetOrCreate()->Register([last_solution_printer]() { + last_solution_printer->MaybePrintLastSolution(); + exit(EXIT_SUCCESS); + }); + } + const CpSolverResponse response = SolveCpModel(*cp_model, &model); if (!absl::GetFlag(FLAGS_output).empty()) { diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 0c1ed51803..18f37e7cfb 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -1386,14 +1386,27 @@ int UniqueClauseStream::NumLiteralsOfSize(int size) const { SharedClausesManager::SharedClausesManager(bool always_synchronize) : always_synchronize_(always_synchronize) {} -int SharedClausesManager::RegisterNewId(bool may_terminate_early) { +int SharedClausesManager::RegisterNewId(absl::string_view worker_name, + bool may_terminate_early) { absl::MutexLock mutex_lock(&mutex_); num_full_workers_ += may_terminate_early ? 0 : 1; const int id = id_to_last_processed_binary_clause_.size(); id_to_last_processed_binary_clause_.resize(id + 1, 0); id_to_last_returned_batch_.resize(id + 1, -1); id_to_last_finished_batch_.resize(id + 1, -1); - id_to_clauses_exported_.resize(id + 1, 0); + id_to_num_exported_.resize(id + 1, 0); + id_to_worker_name_.resize(id + 1); + id_to_worker_name_[id] = worker_name; + return id; +} + +int SharedLinear2Bounds::RegisterNewId(std::string worker_name) { + absl::MutexLock mutex_lock(&mutex_); + const int id = id_to_worker_name_.size(); + + id_to_stats_.resize(id + 1); + id_to_worker_name_.resize(id + 1); + id_to_worker_name_[id] = worker_name; return id; } @@ -1401,12 +1414,6 @@ bool SharedClausesManager::ShouldReadBatch(int reader_id, int writer_id) { return reader_id != writer_id; } -void SharedClausesManager::SetWorkerNameForId(int id, - absl::string_view worker_name) { - absl::MutexLock mutex_lock(&mutex_); - id_to_worker_name_[id] = worker_name; -} - void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { if (lit2 < lit1) std::swap(lit1, lit2); const auto p = std::make_pair(lit1, lit2); @@ -1416,7 +1423,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { if (inserted) { added_binary_clauses_.push_back(p); if (always_synchronize_) ++last_visible_binary_clause_; - id_to_clauses_exported_[id]++; + id_to_num_exported_[id]++; // Small optim. If the worker is already up to date with clauses to import, // we can mark this new clause as already seen. @@ -1429,7 +1436,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { void SharedClausesManager::AddBatch(int id, CompactVectorVector batch) { absl::MutexLock mutex_lock(&mutex_); - id_to_clauses_exported_[id] += batch.size(); + id_to_num_exported_[id] += batch.size(); pending_batches_.push_back(std::move(batch)); } @@ -1463,16 +1470,44 @@ void SharedClausesManager::GetUnseenBinaryClauses( void SharedClausesManager::LogStatistics(SolverLogger* logger) { absl::MutexLock mutex_lock(&mutex_); - absl::btree_map name_to_clauses; - for (int id = 0; id < id_to_clauses_exported_.size(); ++id) { - if (id_to_clauses_exported_[id] == 0) continue; - name_to_clauses[id_to_worker_name_[id]] = id_to_clauses_exported_[id]; + absl::btree_map name_to_table_line; + for (int id = 0; id < id_to_num_exported_.size(); ++id) { + if (id_to_num_exported_[id] == 0) continue; + name_to_table_line[id_to_worker_name_[id]] = id_to_num_exported_[id]; } - if (!name_to_clauses.empty()) { + if (!name_to_table_line.empty()) { std::vector> table; table.push_back({"Clauses shared", "Num"}); - for (const auto& entry : name_to_clauses) { - table.push_back({FormatName(entry.first), FormatCounter(entry.second)}); + for (const auto& [name, count] : name_to_table_line) { + table.push_back({FormatName(name), FormatCounter(count)}); + } + SOLVER_LOG(logger, FormatTable(table)); + } +} + +// TODO(user): Add some library to simplify this "transposition". Ideally we +// could merge small table with few columns. I am thinking list (row_name, +// col_name, count) + function that create table? +void SharedLinear2Bounds::LogStatistics(SolverLogger* logger) { + absl::MutexLock mutex_lock(&mutex_); + absl::btree_map name_to_table_line; + for (int id = 0; id < id_to_stats_.size(); ++id) { + const Stats stats = id_to_stats_[id]; + if (!stats.empty()) { + name_to_table_line[id_to_worker_name_[id]] = stats; + } + } + for (int import_id = 0; import_id < import_id_to_index_.size(); ++import_id) { + name_to_table_line[import_id_to_name_[import_id]].num_imported = + import_id_to_num_imported_[import_id]; + } + if (!name_to_table_line.empty()) { + std::vector> table; + table.push_back({"Linear2 shared", "New", "Updated", "Imported"}); + for (const auto& [name, stats] : name_to_table_line) { + table.push_back({FormatName(name), FormatCounter(stats.num_new), + FormatCounter(stats.num_update), + FormatCounter(stats.num_imported)}); } SOLVER_LOG(logger, FormatTable(table)); } @@ -1522,6 +1557,69 @@ void SharedClausesManager::Synchronize() { } } +void SharedLinear2Bounds::Add(int id, Key expr, IntegerValue lb, + IntegerValue ub) { + DCHECK(expr.IsCanonicalized()); + + absl::MutexLock mutex_lock(&mutex_); + auto [it, inserted] = shared_bounds_.insert({expr, {lb, ub}}); + if (inserted) { + // It is new. + id_to_stats_[id].num_new++; + newly_updated_keys_.push_back(expr); + } else { + // Update the individual bounds if the new ones are better. + auto& bounds = it->second; + const bool update_lb = lb > bounds.first; + if (update_lb) bounds.first = lb; + const bool update_ub = ub < bounds.second; + if (update_ub) bounds.second = ub; + if (update_lb || update_ub) { + id_to_stats_[id].num_update++; + newly_updated_keys_.push_back(expr); + } + } +} + +int SharedLinear2Bounds::RegisterNewImportId(std::string name) { + absl::MutexLock mutex_lock(&mutex_); + const int import_id = import_id_to_index_.size(); + import_id_to_name_.push_back(name); + import_id_to_index_.push_back(0); + import_id_to_num_imported_.push_back(0); + return import_id; +} + +std::vector< + std::pair>> +SharedLinear2Bounds::NewlyUpdatedBounds(int import_id) { + std::vector>> result; + + absl::MutexLock mutex_lock(&mutex_); + MaybeCompressNewlyUpdateKeys(); + const int size = newly_updated_keys_.size(); + for (int i = import_id_to_index_[import_id]; i < size; ++i) { + const auto& key = newly_updated_keys_[i]; + result.push_back({key, shared_bounds_[key]}); + } + import_id_to_index_[import_id] = size; + return result; +} + +void SharedLinear2Bounds::MaybeCompressNewlyUpdateKeys() { + int min_index = 0; + for (const int index : import_id_to_index_) { + min_index = std::min(index, min_index); + } + if (min_index == 0) return; + + newly_updated_keys_.erase(newly_updated_keys_.begin(), + newly_updated_keys_.begin() + min_index); + for (int& index_ref : import_id_to_index_) { + index_ref -= min_index; + } +} + void SharedStatistics::AddStats( absl::Span> stats) { absl::MutexLock mutex_lock(&mutex_); diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index 38722c2264..a9cd377fdb 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -848,8 +848,7 @@ class SharedClausesManager { std::vector>* new_clauses); // Ids are used to identify which worker is exporting/importing clauses. - int RegisterNewId(bool may_terminate_early); - void SetWorkerNameForId(int id, absl::string_view worker_name); + int RegisterNewId(absl::string_view worker_name, bool may_terminate_early); // Search statistics. void LogStatistics(SolverLogger* logger); @@ -893,8 +892,100 @@ class SharedClausesManager { const bool always_synchronize_ = true; // Stats: - std::vector id_to_clauses_exported_; - absl::flat_hash_map id_to_worker_name_; + std::vector id_to_num_exported_ ABSL_GUARDED_BY(mutex_); + std::vector id_to_num_updated_ ABSL_GUARDED_BY(mutex_); + std::vector id_to_worker_name_ ABSL_GUARDED_BY(mutex_); +}; + +// A class that allows to exchange root level bounds on linear2. +// +// TODO(user): Add Synchronize() support and only publish new bounds when this +// is called. +class SharedLinear2Bounds { + public: + int RegisterNewId(std::string worker_name); + void LogStatistics(SolverLogger* logger); + + // This should only contain canonicalized expression. + // See the code for IsCanonicalized() for the definition. + struct Key { + int vars[2]; + IntegerValue coeffs[2]; + + bool IsCanonicalized() { + return coeffs[0] > 0 && coeffs[1] != 0 && vars[0] < vars[1] && + std::gcd(coeffs[0].value(), coeffs[1].value()) == 1; + } + + bool operator==(const Key& o) const { + return vars[0] == o.vars[0] && vars[1] == o.vars[1] && + coeffs[0] == o.coeffs[0] && coeffs[1] == o.coeffs[1]; + } + + template + friend H AbslHashValue(H h, const Key& k) { + return H::combine(std::move(h), k.vars[0], k.vars[1], k.coeffs[0], + k.coeffs[1]); + } + }; + + // Exports new bounds on the given expr (should be canonicalized). + void Add(int id, Key expr, IntegerValue lb, IntegerValue ub); + + // This is called less often, and maybe not every-worker that exports want to + // export, so we use a separate id space. Because we rely on hash map to + // check if a bound is new, it is not such a big deal that a worker re-read + // once the bounds it exported. + int RegisterNewImportId(std::string name); + + // Returns the linear2 and their bounds. + // We only return changes since the last call with the same id. + std::vector>> + NewlyUpdatedBounds(int import_id); + + // This is not filled by NewlyUpdatedBounds() because we want to track the + // bounds that were not already known by the worker at the time of the import, + // and we don't have this information here. + void NotifyNumImported(int import_id, int num) { + absl::MutexLock mutex_lock(&mutex_); + import_id_to_num_imported_[import_id] += num; + } + + private: + void MaybeCompressNewlyUpdateKeys() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + absl::Mutex mutex_; + + // The best known bounds for each key. + absl::flat_hash_map> shared_bounds_ + ABSL_GUARDED_BY(mutex_); + + // Ever growing list of updated position in shared_bounds_. + // Note that we do reduce it in MaybeCompressNewlyUpdateKeys(), but that + // requires all registered workers to have at least imported some bounds. + // + // TODO(user): use indirect addressing so that newly_updated_keys_ can just + // deal with indices, and it is a bit tighter memory wise? We also avoid + // hash-lookups on NewlyUpdatedBounds(). But since this is only called at + // level zero on new bounds, I don't think we care. + std::vector newly_updated_keys_; + + // For import. + std::vector import_id_to_name_ ABSL_GUARDED_BY(mutex_); + std::vector import_id_to_index_ ABSL_GUARDED_BY(mutex_); + std::vector import_id_to_num_imported_ ABSL_GUARDED_BY(mutex_); + + // Just for reporting at the end of the solve. + struct Stats { + int64_t num_new = 0; + int64_t num_update = 0; + int64_t num_imported = 0; // Copy of import_id_to_num_imported_. + bool empty() const { + return num_new == 0 && num_update == 0 && num_imported == 0; + } + }; + std::vector id_to_stats_ ABSL_GUARDED_BY(mutex_); + std::vector id_to_worker_name_ ABSL_GUARDED_BY(mutex_); }; // Simple class to add statistics by name and print them at the end. diff --git a/ortools/sat/synchronization_test.cc b/ortools/sat/synchronization_test.cc index 00dd4a2550..1ab19d6cbc 100644 --- a/ortools/sat/synchronization_test.cc +++ b/ortools/sat/synchronization_test.cc @@ -834,8 +834,8 @@ TEST(SharedResponseManagerTest, Callback) { TEST(SharedClausesManagerTest, SyncApi) { SharedClausesManager manager(/*always_synchronize=*/true); - EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + EXPECT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + EXPECT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); manager.AddBinaryClause(/*id=*/0, 1, 2); std::vector> new_clauses; @@ -922,8 +922,8 @@ TEST(UniqueClauseStreamTest, DropsClauses) { TEST(SharedClausesManagerTest, NonSyncApi) { SharedClausesManager manager(/*always_synchronize=*/false); - EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + EXPECT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + EXPECT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); manager.AddBinaryClause(/*id=*/0, 1, 2); std::vector> new_clauses; @@ -971,8 +971,8 @@ TEST(SharedClausesManagerTest, NonSyncApi) { TEST(SharedClausesManagerTest, ShareGlueClauses) { SharedClausesManager manager(/*always_synchronize=*/true); - ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); UniqueClauseStream stream0; UniqueClauseStream stream1; // Add a bunch of clauses that will be skipped batch. @@ -999,8 +999,8 @@ TEST(SharedClausesManagerTest, ShareGlueClauses) { TEST(SharedClausesManagerTest, LbdThresholdIncrease) { SharedClausesManager manager(/*always_synchronize=*/true); - ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); UniqueClauseStream stream0; UniqueClauseStream stream1; const int kExpectedClauses = UniqueClauseStream::kMaxLiteralsPerBatch / 5; @@ -1027,8 +1027,8 @@ TEST(SharedClausesManagerTest, LbdThresholdIncrease) { TEST(SharedClausesManagerTest, LbdThresholdDecrease) { SharedClausesManager manager(/*always_synchronize=*/true); - ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); UniqueClauseStream stream0; UniqueClauseStream stream1; diff --git a/ortools/util/sigint.cc b/ortools/util/sigint.cc index 601f4983cc..bd4f40cfac 100644 --- a/ortools/util/sigint.cc +++ b/ortools/util/sigint.cc @@ -23,29 +23,47 @@ namespace operations_research { void SigintHandler::Register(const std::function& f) { handler_ = [this, f]() -> void { - const int num_sigint_calls = ++num_sigint_calls_; - if (num_sigint_calls < 3) { + const int num_calls = ++num_calls_; + if (num_calls < 3) { LOG(INFO) - << "^C pressed " << num_sigint_calls << " times. " + << "^C pressed " << num_calls << " times. " << "Interrupting the solver. Press 3 times to force termination."; - if (num_sigint_calls == 1) f(); - } else if (num_sigint_calls == 3) { + if (num_calls == 1) f(); + } else if (num_calls == 3) { LOG(INFO) << "^C pressed 3 times. Forcing termination."; exit(EXIT_FAILURE); } else { // Another thread is already running exit(), do nothing. } }; - signal(SIGINT, &ControlCHandler); + signal(SIGINT, &SigHandler); } // This method will be called by the system after the SIGINT signal. // The parameter is the signal received. -void SigintHandler::ControlCHandler(int sig) { handler_(); } +void SigintHandler::SigHandler(int) { handler_(); } -// Unregister the SIGINT handler. -SigintHandler::~SigintHandler() { signal(SIGINT, SIG_DFL); } +// Unregister the signal handlers. +SigintHandler::~SigintHandler() { + if (handler_ != nullptr) signal(SIGINT, SIG_DFL); +} thread_local std::function SigintHandler::handler_; +void SigtermHandler::Register(const std::function& f) { + handler_ = [f]() -> void { f(); }; + signal(SIGTERM, &SigHandler); +} + +// This method will be called by the system after the SIGTERM signal. +// The parameter is the signal received. +void SigtermHandler::SigHandler(int) { handler_(); } + +// Unregister the signal handlers. +SigtermHandler::~SigtermHandler() { + if (handler_ != nullptr) signal(SIGTERM, SIG_DFL); +} + +thread_local std::function SigtermHandler::handler_; + } // namespace operations_research diff --git a/ortools/util/sigint.h b/ortools/util/sigint.h index 7b3098033e..1d9fcd1b81 100644 --- a/ortools/util/sigint.h +++ b/ortools/util/sigint.h @@ -21,7 +21,7 @@ namespace operations_research { class SigintHandler { public: - SigintHandler() {} + SigintHandler() = default; ~SigintHandler(); // Catches ^C and call f() the first time this happen. If ^C is pressed 3 @@ -29,9 +29,23 @@ class SigintHandler { void Register(const std::function& f); private: - static void ControlCHandler(int s); + std::atomic num_calls_ = 0; - std::atomic num_sigint_calls_ = 0; + static void SigHandler(int s); + thread_local static std::function handler_; +}; + +class SigtermHandler { + public: + SigtermHandler() = default; + ~SigtermHandler(); + + // Catches SIGTERM and call f(). It is recommended that f() calls exit() to + // terminate the program. + void Register(const std::function& f); + + private: + static void SigHandler(int s); thread_local static std::function handler_; }; diff --git a/ortools/util/sorted_interval_list.h b/ortools/util/sorted_interval_list.h index f07dca7c71..fb62e30d27 100644 --- a/ortools/util/sorted_interval_list.h +++ b/ortools/util/sorted_interval_list.h @@ -724,7 +724,9 @@ class ClosedInterval::Iterator { // arithmetic. uint64_t current_; }; - +#if __cplusplus >= 202002L +static_assert(std::input_iterator); +#endif // begin()/end() are required for iteration over ClosedInterval in a range for // loop. inline ClosedInterval::Iterator begin(ClosedInterval interval) {