From d432627bbc50b9e69951192474a3fe8def093589 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 24 Mar 2025 04:53:51 -0700 Subject: [PATCH] cleanups --- examples/cpp/nqueens.cc | 6 ++++- ortools/sat/cp_model_presolve.cc | 38 +++++++++++++++++++++----------- ortools/sat/cp_model_solver.h | 4 ++-- ortools/sat/docs/solver.md | 6 +---- ortools/sat/model.h | 1 + 5 files changed, 34 insertions(+), 21 deletions(-) diff --git a/examples/cpp/nqueens.cc b/examples/cpp/nqueens.cc index 9bb8dee420..a6ec13fa00 100644 --- a/examples/cpp/nqueens.cc +++ b/examples/cpp/nqueens.cc @@ -29,6 +29,7 @@ #include "ortools/base/logging.h" #include "ortools/base/map_util.h" #include "ortools/base/types.h" +#include "ortools/constraint_solver/constraint_solver.h" #include "ortools/constraint_solver/constraint_solveri.h" ABSL_FLAG(bool, print, false, "If true, print one of the solution."); @@ -39,7 +40,6 @@ ABSL_FLAG( int, size, 0, "Size of the problem. If equal to 0, will test several increasing sizes."); ABSL_FLAG(bool, use_symmetry, false, "Use Symmetry Breaking methods"); -ABSL_DECLARE_FLAG(bool, cp_disable_solve); static const int kNumSolutions[] = { 1, 0, 0, 2, 10, 4, 40, 92, 352, 724, 2680, 14200, 73712, 365596, 2279184}; @@ -182,10 +182,14 @@ void CheckNumberOfSolutions(int size, int num_solutions) { if (absl::GetFlag(FLAGS_use_symmetry)) { if (size - 1 < kKnownUniqueSolutions) { CHECK_EQ(num_solutions, kNumUniqueSolutions[size - 1]); + } else if (!absl::GetFlag(FLAGS_cp_disable_solve)) { + CHECK_GT(num_solutions, 0); } } else { if (size - 1 < kKnownSolutions) { CHECK_EQ(num_solutions, kNumSolutions[size - 1]); + } else if (!absl::GetFlag(FLAGS_cp_disable_solve)) { + CHECK_GT(num_solutions, 0); } } } diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index c900f426fe..c97447e9dd 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -5457,19 +5457,26 @@ bool CpModelPresolver::PresolveTable(ConstraintProto* ct) { namespace { // A container that is valid if only one value was added. -struct UniqueNonNegativeValue { - int index = -1; - - void Add(int new_index) { - DCHECK_GE(index, 0); - if (index == -1) { - index = new_index; +class UniqueNonNegativeValue { + public: + void Add(int value) { + DCHECK_GE(value, 0); + if (value_ == -1) { + value_ = value; } else { - index = -2; + value_ = -2; } } - bool IsValid() const { return index >= 0; } + bool HasUniqueValue() const { return value_ >= 0; } + + int64_t value() const { + DCHECK(HasUniqueValue()); + return value_; + } + + private: + int value_ = -1; }; } // namespace @@ -5492,7 +5499,7 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) { return RemoveConstraint(ct); } if (size == 1) { - context_->UpdateRuleStats("all_diff: only one expression"); + context_->UpdateRuleStats("all_diff: one expression"); return RemoveConstraint(ct); } @@ -5530,7 +5537,7 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) { } } if (propagated) { - context_->UpdateRuleStats("all_diff: propagate fixed values"); + context_->UpdateRuleStats("all_diff: propagate fixed expressions"); } } @@ -5610,9 +5617,10 @@ bool CpModelPresolver::PresolveAllDiff(ConstraintProto* ct) { bool propagated = false; for (const auto& [value, unique_index] : value_to_index) { - if (!unique_index.IsValid()) continue; + if (!unique_index.HasUniqueValue()) continue; - const LinearExpressionProto& expr = all_diff.exprs(unique_index.index); + const LinearExpressionProto& expr = + all_diff.exprs(unique_index.value()); if (!context_->IntersectDomainWith(expr, Domain(value), &propagated)) { return true; } @@ -7762,6 +7770,8 @@ void CpModelPresolver::Probe() { return (void)context_->NotifyThatModelIsUnsat("during probing"); } + time_limit_->ResetHistory(); + // Update the presolve context with fixed Boolean variables. int num_fixed = 0; CHECK_EQ(sat_solver->CurrentDecisionLevel(), 0); @@ -8694,6 +8704,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() { // We reuse the max-clique code from sat. Model local_model; local_model.GetOrCreate()->Resize(num_constraints); + local_model.GetOrCreate()->MergeWithGlobalTimeLimit(time_limit_); auto* graph = local_model.GetOrCreate(); graph->Resize(num_constraints); for (const std::vector& clique : cliques) { @@ -8730,6 +8741,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() { new_num_intervals, " intervals)."); context_->UpdateRuleStats("no_overlap: merged constraints"); } + time_limit_->ResetHistory(); } // TODO(user): Should we take into account the exactly_one constraints? note diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h index 9079b1a10b..b4ab8eec2a 100644 --- a/ortools/sat/cp_model_solver.h +++ b/ortools/sat/cp_model_solver.h @@ -128,8 +128,8 @@ std::function NewSatParameters( std::function NewSatParameters( const SatParameters& parameters); -// Stops the current search. -void StopSearch(Model*); +/// Stops the current search. +void StopSearch(Model* model); // TODO(user): Clean this up. /// Solves a CpModelProto without any processing. Only used for unit tests. diff --git a/ortools/sat/docs/solver.md b/ortools/sat/docs/solver.md index 6bff2619c0..c538ad9b6b 100644 --- a/ortools/sat/docs/solver.md +++ b/ortools/sat/docs/solver.md @@ -1025,10 +1025,6 @@ void StopAfterNSolutionsSampleSat() { parameters.set_enumerate_all_solutions(true); model.Add(NewSatParameters(parameters)); - // Create an atomic Boolean that will be periodically checked by the limit. - std::atomic stopped(false); - model.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped); - const int kSolutionLimit = 5; int num_solutions = 0; model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) { @@ -1038,7 +1034,7 @@ void StopAfterNSolutionsSampleSat() { LOG(INFO) << " z = " << SolutionIntegerValue(r, z); num_solutions++; if (num_solutions >= kSolutionLimit) { - stopped = true; + StopSearch(&model); LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions."; } })); diff --git a/ortools/sat/model.h b/ortools/sat/model.h index a48c746e0b..88f8140a32 100644 --- a/ortools/sat/model.h +++ b/ortools/sat/model.h @@ -24,6 +24,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/meta/type_traits.h" #include "ortools/base/logging.h" #include "ortools/base/typeid.h"