From d7b250cf5e731fd6e9341f089b9835b6772178d7 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sun, 23 Mar 2025 20:46:46 -0700 Subject: [PATCH] Add StopSearch function in C++ CP-SAT --- examples/cpp/network_routing_sat.cc | 5 +-- examples/cpp/variable_intervals_sat.cc | 6 +--- ortools/sat/BUILD.bazel | 1 - ortools/sat/cp_model_solver.cc | 4 +++ ortools/sat/cp_model_solver.h | 3 ++ ortools/sat/model.cc | 34 ------------------- ortools/sat/model.h | 22 ++++++------ ortools/sat/samples/nurses_sat.cc | 6 +--- .../stop_after_n_solutions_sample_sat.cc | 6 +--- 9 files changed, 23 insertions(+), 64 deletions(-) delete mode 100644 ortools/sat/model.cc diff --git a/examples/cpp/network_routing_sat.cc b/examples/cpp/network_routing_sat.cc index dea1177834..0072f8d4b5 100644 --- a/examples/cpp/network_routing_sat.cc +++ b/examples/cpp/network_routing_sat.cc @@ -403,9 +403,6 @@ class NetworkRoutingSolver { cp_model.AddAllDifferent(node_vars); Model model; - // Create an atomic Boolean that will be periodically checked by the limit. - std::atomic stopped(false); - model.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped); model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) { const int path_id = all_paths_[demand_index].size(); @@ -415,7 +412,7 @@ class NetworkRoutingSolver { all_paths_[demand_index].back().insert(arc); } if (all_paths_[demand_index].size() >= max_paths) { - stopped = true; + StopSearch(&model); } })); diff --git a/examples/cpp/variable_intervals_sat.cc b/examples/cpp/variable_intervals_sat.cc index 3c5d8a25ea..2335f4bb04 100644 --- a/examples/cpp/variable_intervals_sat.cc +++ b/examples/cpp/variable_intervals_sat.cc @@ -53,10 +53,6 @@ void Solve() { parameters.set_enumerate_all_solutions(true); model.Add(NewSatParameters(parameters)); - // Create an atomic Boolean that will be periodically checked by the limit. - std::atomic stopped(false); - model.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped); - const int kSolutionLimit = 100; int num_solutions = 0; model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) { @@ -68,7 +64,7 @@ void Solve() { LOG(INFO) << " start_ins = " << SolutionIntegerValue(r, start_ins); num_solutions++; if (num_solutions >= kSolutionLimit) { - stopped = true; + StopSearch(&model); LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions."; } })); diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 197ca0d33f..22a0567aaa 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -43,7 +43,6 @@ cc_library( cc_library( name = "model", - srcs = ["model.cc"], hdrs = ["model.h"], deps = [ "//ortools/base", diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index bc0aeaaeee..7dce9e5a1c 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -2269,6 +2269,10 @@ std::function NewSatParameters( }; } +void StopSearch(Model* model) { + model->GetOrCreate()->Stop(); +} + namespace { void RegisterSearchStatisticCallback(Model* global_model) { global_model->GetOrCreate() diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h index ca3c306eae..9079b1a10b 100644 --- a/ortools/sat/cp_model_solver.h +++ b/ortools/sat/cp_model_solver.h @@ -128,6 +128,9 @@ std::function NewSatParameters( std::function NewSatParameters( const SatParameters& parameters); +// Stops the current search. +void StopSearch(Model*); + // TODO(user): Clean this up. /// Solves a CpModelProto without any processing. Only used for unit tests. void LoadAndSolveCpModelForTest(const CpModelProto& model_proto, Model* model); diff --git a/ortools/sat/model.cc b/ortools/sat/model.cc deleted file mode 100644 index 21cf5324df..0000000000 --- a/ortools/sat/model.cc +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2010-2025 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ortools/sat/model.h" - -#include - -#include "absl/log/check.h" - -namespace operations_research { -namespace sat { - -void Model::AddNewSingleton(void* new_element, size_t type_id) { - CHECK(singletons_.emplace(type_id, new_element).second) - << "Duplicate type id: " << type_id; -} - -void* Model::GetSingletonOrNullptr(size_t type_id) const { - const auto it = singletons_.find(type_id); - return it != singletons_.end() ? it->second : nullptr; -} - -} // namespace sat -} // namespace operations_research diff --git a/ortools/sat/model.h b/ortools/sat/model.h index e3099a9d26..a48c746e0b 100644 --- a/ortools/sat/model.h +++ b/ortools/sat/model.h @@ -110,15 +110,15 @@ class Model { template T* GetOrCreate() { const size_t type_id = gtl::FastTypeId(); - void* find = GetSingletonOrNullptr(type_id); - if (find != nullptr) { - return static_cast(find); + auto find = singletons_.find(type_id); + if (find != singletons_.end()) { + return static_cast(find->second); } // New element. // TODO(user): directly store std::unique_ptr<> in singletons_? T* new_t = MyNew(0); - AddNewSingleton(new_t, type_id); + singletons_[type_id] = new_t; TakeOwnership(new_t); return new_t; } @@ -130,7 +130,9 @@ class Model { */ template const T* Get() const { - return static_cast(GetSingletonOrNullptr(gtl::FastTypeId())); + const auto& it = singletons_.find(gtl::FastTypeId()); + return it != singletons_.end() ? static_cast(it->second) + : nullptr; } /** @@ -138,7 +140,8 @@ class Model { */ template T* Mutable() const { - return static_cast(GetSingletonOrNullptr(gtl::FastTypeId())); + const auto& it = singletons_.find(gtl::FastTypeId()); + return it != singletons_.end() ? static_cast(it->second) : nullptr; } /** @@ -171,7 +174,9 @@ class Model { */ template void Register(T* non_owned_class) { - AddNewSingleton(non_owned_class, gtl::FastTypeId()); + const size_t type_id = gtl::FastTypeId(); + CHECK(!singletons_.contains(type_id)); + singletons_[type_id] = non_owned_class; } const std::string& Name() const { return name_; } @@ -191,9 +196,6 @@ class Model { return new T(); } - void AddNewSingleton(void* new_element, size_t type_id); - void* GetSingletonOrNullptr(size_t type_id) const; - const std::string name_; // Map of FastTypeId to a "singleton" of type T. diff --git a/ortools/sat/samples/nurses_sat.cc b/ortools/sat/samples/nurses_sat.cc index a4bcfa094e..29197bf1bd 100644 --- a/ortools/sat/samples/nurses_sat.cc +++ b/ortools/sat/samples/nurses_sat.cc @@ -138,10 +138,6 @@ void NurseSat() { // Display the first five solutions. // [START solution_printer] - // Create an atomic Boolean that will be periodically checked by the limit. - std::atomic stopped(false); - model.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped); - const int kSolutionLimit = 5; int num_solutions = 0; model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) { @@ -165,7 +161,7 @@ void NurseSat() { } num_solutions++; if (num_solutions >= kSolutionLimit) { - stopped = true; + StopSearch(&model); LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions."; } })); diff --git a/ortools/sat/samples/stop_after_n_solutions_sample_sat.cc b/ortools/sat/samples/stop_after_n_solutions_sample_sat.cc index 4dfdccfafc..76c55ebe60 100644 --- a/ortools/sat/samples/stop_after_n_solutions_sample_sat.cc +++ b/ortools/sat/samples/stop_after_n_solutions_sample_sat.cc @@ -43,10 +43,6 @@ void StopAfterNSolutionsSampleSat() { parameters.set_enumerate_all_solutions(true); model.Add(NewSatParameters(parameters)); - // Create an atomic Boolean that will be periodically checked by the limit. - std::atomic stopped(false); - model.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped); - const int kSolutionLimit = 5; int num_solutions = 0; model.Add(NewFeasibleSolutionObserver([&](const CpSolverResponse& r) { @@ -56,7 +52,7 @@ void StopAfterNSolutionsSampleSat() { LOG(INFO) << " z = " << SolutionIntegerValue(r, z); num_solutions++; if (num_solutions >= kSolutionLimit) { - stopped = true; + StopSearch(&model); LOG(INFO) << "Stop search after " << kSolutionLimit << " solutions."; } }));