From de5fbc46ab6aedac6bf6a899f4467ba35732f6ef Mon Sep 17 00:00:00 2001 From: Corentin Le Molgat Date: Wed, 20 Aug 2025 11:25:07 +0200 Subject: [PATCH] util/python: add solve interrupter --- ortools/util/python/BUILD.bazel | 81 +++++++ ortools/util/python/py_solve_interrupter.cc | 91 +++++++ ortools/util/python/py_solve_interrupter.h | 116 +++++++++ .../python/py_solve_interrupter_testing.cc | 29 +++ .../python/py_solve_interrupter_testing.h | 62 +++++ .../util/python/pybind_solve_interrupter.cc | 41 ++++ .../python/pybind_solve_interrupter_test.py | 139 +++++++++++ .../pybind_solve_interrupter_testing.cc | 34 +++ ortools/util/python/solve_interrupter.py | 136 +++++++++++ ortools/util/python/solve_interrupter_test.py | 226 ++++++++++++++++++ 10 files changed, 955 insertions(+) create mode 100644 ortools/util/python/py_solve_interrupter.cc create mode 100644 ortools/util/python/py_solve_interrupter.h create mode 100644 ortools/util/python/py_solve_interrupter_testing.cc create mode 100644 ortools/util/python/py_solve_interrupter_testing.h create mode 100644 ortools/util/python/pybind_solve_interrupter.cc create mode 100644 ortools/util/python/pybind_solve_interrupter_test.py create mode 100644 ortools/util/python/pybind_solve_interrupter_testing.cc create mode 100644 ortools/util/python/solve_interrupter.py create mode 100644 ortools/util/python/solve_interrupter_test.py diff --git a/ortools/util/python/BUILD.bazel b/ortools/util/python/BUILD.bazel index 701455047d..26fba1de17 100644 --- a/ortools/util/python/BUILD.bazel +++ b/ortools/util/python/BUILD.bazel @@ -16,6 +16,7 @@ load("@pip_deps//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_python//python:py_library.bzl", "py_library") load("@rules_python//python:py_test.bzl", "py_test") cc_library( @@ -44,3 +45,83 @@ py_test( requirement("absl-py"), ], ) + +cc_library( + name = "py_solve_interrupter_lib", + srcs = ["py_solve_interrupter.cc"], + hdrs = ["py_solve_interrupter.h"], + # This library is not meant to be consumed end users; only by C++ code of + # Clif libraries that needs a SolverInterrupter. + visibility = [ + "//ortools/math_opt/core/python:__subpackages__", + "//ortools/math_opt/python:__subpackages__", + ], + deps = [ + "//ortools/util:solve_interrupter", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/synchronization", + ], +) + +cc_library( + name = "py_solve_interrupter_testing_lib", + testonly = True, + srcs = ["py_solve_interrupter_testing.cc"], + hdrs = ["py_solve_interrupter_testing.h"], + deps = [ + ":py_solve_interrupter_lib", + "@abseil-cpp//absl/base:nullability", + ], +) + +pybind_extension( + name = "pybind_solve_interrupter", + srcs = ["pybind_solve_interrupter.cc"], + # This library is not meant to be consumed end users; only pybind11 code + # that needs a SolverInterrupter. + visibility = [ + "//ortools/math_opt/core/python:__subpackages__", + ], + deps = [":py_solve_interrupter_lib"], +) + +pybind_extension( + name = "pybind_solve_interrupter_testing", + testonly = True, + srcs = ["pybind_solve_interrupter_testing.cc"], + deps = [":py_solve_interrupter_testing_lib"], +) + +py_test( + name = "pybind_solve_interrupter_test", + srcs = ["pybind_solve_interrupter_test.py"], + deps = [ + ":pybind_solve_interrupter", + ":pybind_solve_interrupter_testing", + requirement("absl-py"), + ], +) + +py_library( + name = "solve_interrupter", + srcs = ["solve_interrupter.py"], + visibility = ["//visibility:public"], + deps = [ + ":py_solve_interrupter_lib", + ":pybind_solve_interrupter", + requirement("absl-py"), + ], +) + +py_test( + name = "solve_interrupter_test", + srcs = ["solve_interrupter_test.py"], + deps = [ + ":py_solve_interrupter_testing_lib", + ":pybind_solve_interrupter_testing", + ":solve_interrupter", + requirement("absl-py"), + ], +) diff --git a/ortools/util/python/py_solve_interrupter.cc b/ortools/util/python/py_solve_interrupter.cc new file mode 100644 index 0000000000..20d7aef543 --- /dev/null +++ b/ortools/util/python/py_solve_interrupter.cc @@ -0,0 +1,91 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/util/python/py_solve_interrupter.h" + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/log.h" +#include "absl/synchronization/mutex.h" + +namespace operations_research { + +PySolveInterrupter::PySolveInterrupter() + : callback_(&interrupter_, [this]() { TriggerTargets(); }) {} + +std::vector> +PySolveInterrupter::CleanupAndGetTargets( + const PySolveInterrupter* absl_nullable const to_remove) { + // First get strong references of non-expired targets. Also filter the + // to_remove target if found. + std::vector> + non_expired_targets; + non_expired_targets.reserve(targets_.size()); + for (std::weak_ptr& weak_target : targets_) { + std::shared_ptr strong_target = weak_target.lock(); + if (strong_target != nullptr && strong_target.get() != to_remove) { + non_expired_targets.push_back(std::move(strong_target)); + } + } + + // Then recreates targets weak-references with only the non-expired targets. + // + // Note that we could be more efficient by doing the cleanup in-place but here + // we keep the code simple. In practice the number of targets_ is expected to + // be very low (less than 10). + targets_.clear(); + for (std::shared_ptr& strong_target : + non_expired_targets) { + targets_.push_back(strong_target); + } + + return non_expired_targets; +} + +void PySolveInterrupter::AddTriggerTarget( + absl_nonnull std::shared_ptr target) { + const absl::MutexLock lock(&mutex_); + CleanupAndGetTargets(); + // Note that we don't test if targets_ already contain the interrupter as + // interrupter are triggerable only once. Thus having duplicates won't result + // in visible changes outside. + // + // The way RemoveTriggerTarget() is implemented will remove duplicates as + // well. + targets_.push_back(std::move(target)); +} + +void PySolveInterrupter::RemoveTriggerTarget( + absl_nonnull std::shared_ptr target) { + const absl::MutexLock lock(&mutex_); + CleanupAndGetTargets(/*to_remove=*/target.get()); +} + +void PySolveInterrupter::TriggerTargets() { + std::vector> targets; + { // Limit `lock` scope. + const absl::MutexLock lock(&mutex_); + targets = CleanupAndGetTargets(); + } + + // Call targets without holding mutex_. + for (const absl_nonnull std::shared_ptr& target : + targets) { + target->Interrupt(); + } +} + +} // namespace operations_research diff --git a/ortools/util/python/py_solve_interrupter.h b/ortools/util/python/py_solve_interrupter.h new file mode 100644 index 0000000000..ba1225919b --- /dev/null +++ b/ortools/util/python/py_solve_interrupter.h @@ -0,0 +1,116 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_ +#define OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "ortools/util/solve_interrupter.h" + +namespace operations_research { + +// Simple class that wraps a SolveInterrupter for Clif or pybind11. +// +class PySolveInterrupter { + public: + PySolveInterrupter(); + + PySolveInterrupter(const PySolveInterrupter&) = delete; + PySolveInterrupter& operator=(const PySolveInterrupter&) = delete; + + // Interrupts the solve as soon as possible. + void Interrupt() { interrupter_.Interrupt(); } + + // Returns true if the solve interruption has been requested. + inline bool IsInterrupted() const { return interrupter_.IsInterrupted(); } + + // Triggers the target when this interrupter is triggered. + // + // A std::weak_ptr is kept on the target. Expired std::weak_ptr are cleaned up + // on calls to AddTriggerTarget() and RemoveTriggerTarget(). + // + // Complexity: O(num_targets). + void AddTriggerTarget(absl_nonnull std::shared_ptr target) + ABSL_LOCKS_EXCLUDED(mutex_); + + // Removes the target if not null and present, else do nothing. + // + // Complexity: O(num_targets). + void RemoveTriggerTarget( + absl_nonnull std::shared_ptr target) + ABSL_LOCKS_EXCLUDED(mutex_); + + // Add a callback on the interrupter and returns an id to use to remove it. + // + // See SolveInterrupter::AddInterruptionCallback(). + int64_t AddInterruptionCallback(std::function callback) const { + return interrupter_.AddInterruptionCallback(std::move(callback)).value(); + } + + // Remove a callback previously registered by AddInterruptionCallback(). + // + // See SolveInterrupter::RemoveInterruptionCallback(). + void RemoveInterruptionCallback(int64_t callback_id) const { + interrupter_.RemoveInterruptionCallback( + SolveInterrupter::CallbackId{callback_id}); + } + + // Return the underlying interrupter. This method is not exposed in Python and + // only available for C++ code. + // + // The lifetime of the PySolveInterrupter is controlled by Python. To prevent + // issues where the underlying SolveInterrupter would be destroyed while still + // being pointed to by C++ code, C++ Clif/pybind11 consumer code should take + // references to PySolveInterrupter in an std::shared_ptr that outlives any + // reference. + const SolveInterrupter* absl_nonnull interrupter() const { + return &interrupter_; + } + + private: + // Remove expired targets_, the optional to_remove target and return strong + // references on non-expired targets. + // + // The caller must have acquired mutex_. + // + // Complexity: O(num_targets). + std::vector> + CleanupAndGetTargets(const PySolveInterrupter* absl_nullable to_remove = + nullptr) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Triggers all non-expired targets_ interrupters. + void TriggerTargets() ABSL_LOCKS_EXCLUDED(mutex_); + + SolveInterrupter interrupter_; + + absl::Mutex mutex_; + + // Interrupters to trigger when interrupter_ is triggered. + std::vector> targets_ + ABSL_GUARDED_BY(mutex_); + + // Callback that will trigger all interrupters in target_. + // + // It MUST appear after targets_ and interrupter_ as we want to make sure it + // is destroyed first! + ScopedSolveInterrupterCallback callback_; +}; + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_ diff --git a/ortools/util/python/py_solve_interrupter_testing.cc b/ortools/util/python/py_solve_interrupter_testing.cc new file mode 100644 index 0000000000..1c57248970 --- /dev/null +++ b/ortools/util/python/py_solve_interrupter_testing.cc @@ -0,0 +1,29 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/util/python/py_solve_interrupter_testing.h" + +#include + +#include "ortools/util/python/py_solve_interrupter.h" + +namespace operations_research { + +std::optional IsInterrupted(const PySolveInterrupter* interrupter) { + if (interrupter == nullptr) { + return std::nullopt; + } + return interrupter->IsInterrupted(); +} + +} // namespace operations_research diff --git a/ortools/util/python/py_solve_interrupter_testing.h b/ortools/util/python/py_solve_interrupter_testing.h new file mode 100644 index 0000000000..ae5d2948b1 --- /dev/null +++ b/ortools/util/python/py_solve_interrupter_testing.h @@ -0,0 +1,62 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Library to unit test PySolveInterrupter wrapper. +#ifndef OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_ +#define OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "ortools/util/python/py_solve_interrupter.h" + +namespace operations_research { + +// Returns: +// * nullopt if `interrupter` is nullptr, +// * or false if `interrupter` is not nullptr and is not interrupted, +// * or true if `interrupted` is not nullptr and is interrupted. +// +// The Clif/pybind11 wrapper will return a `bool | None` value, with None for +// nullopt. +std::optional IsInterrupted(const PySolveInterrupter* interrupter); + +// Class that keeps a reference on a std::shared_ptr to +// test that the C++ object survive the cleanup of the Python reference. +class PySolveInterrupterReference { + public: + explicit PySolveInterrupterReference( + absl_nonnull std::shared_ptr interrupter) + : interrupter_(std::move(interrupter)) {} + + PySolveInterrupterReference(const PySolveInterrupterReference&) = delete; + PySolveInterrupterReference& operator=(const PySolveInterrupterReference&) = + delete; + + // Returns the std::shared_ptr::use_count(). + // + // This is used to test that Python has stopped pointing to the object. + int use_count() const { return interrupter_.use_count(); } + + // Returns true if the underlying interrupter is interrupted. + bool is_interrupted() const { return interrupter_->IsInterrupted(); } + + private: + const absl_nonnull std::shared_ptr interrupter_; +}; + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_ diff --git a/ortools/util/python/pybind_solve_interrupter.cc b/ortools/util/python/pybind_solve_interrupter.cc new file mode 100644 index 0000000000..dbc0514202 --- /dev/null +++ b/ortools/util/python/pybind_solve_interrupter.cc @@ -0,0 +1,41 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "ortools/util/python/py_solve_interrupter.h" + +namespace operations_research { + +namespace py = ::pybind11; + +PYBIND11_MODULE(pybind_solve_interrupter, m) { + py::class_>( + m, "PySolveInterrupter") + .def(py::init()) + .def("interrupt", &PySolveInterrupter::Interrupt) + .def_property_readonly("interrupted", &PySolveInterrupter::IsInterrupted) + .def("add_trigger_target", &PySolveInterrupter::AddTriggerTarget, + py::arg("target")) + .def("remove_trigger_target", &PySolveInterrupter::RemoveTriggerTarget, + py::arg("target")) + .def("add_interruption_callback", + &PySolveInterrupter::AddInterruptionCallback, py::arg("callback")) + .def("remove_interruption_callback", + &PySolveInterrupter::RemoveInterruptionCallback, + py::arg("callback_id")); +} + +} // namespace operations_research diff --git a/ortools/util/python/pybind_solve_interrupter_test.py b/ortools/util/python/pybind_solve_interrupter_test.py new file mode 100644 index 0000000000..677ae5141d --- /dev/null +++ b/ortools/util/python/pybind_solve_interrupter_test.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from ortools.util.python import pybind_solve_interrupter +from ortools.util.python import pybind_solve_interrupter_testing + + +class PybindPySolveInterrupterTest(absltest.TestCase): + """Test pybind_solve_interrupter.PySolveInterrupter class. + + This tests how pybind_solve_interrupter.PySolveInterrupter is accepted by + other + pybind-generated APIs in Python, using pybind_solve_interrupter_testing to do + so. + """ + + def test_none_interrupter(self) -> None: + """Test that a `const PySolveInterrupter*` can receive nullptr.""" + # Here IsInterrupted() is a Clif wrapped C++ function that expects a `const + # PySolveInterrupter*`. It returns None (std::nullopt) for nullptr + # input. + self.assertIsNone(pybind_solve_interrupter_testing.IsInterrupted(None)) + + def test_untriggered_interrupter(self) -> None: + """Test that an untriggered interrupter is properly passed.""" + interrupter = pybind_solve_interrupter.PySolveInterrupter() + # Test the getter. + self.assertFalse(interrupter.interrupted) + # Test using a Clif wrapped C++ function. We test it is not None to make + # sure Clif passes a non-null pointer to the C++ code. We thus have to use + # assertIs() instead of assertFalse(). + self.assertIs( + pybind_solve_interrupter_testing.IsInterrupted(interrupter), False + ) + + def test_triggered_interrupter(self) -> None: + """Test that an untriggered interrupter is properly passed.""" + interrupter = pybind_solve_interrupter.PySolveInterrupter() + interrupter.interrupt() + # Test the getter. + self.assertTrue(interrupter.interrupted) + # Test using a Clif wrapped C++ function. + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(interrupter)) + + def test_shared_reference(self) -> None: + """Test that taking a std::shared_ptr works.""" + # Create an interrupter and a PySolveInterrupterReference class which + # constructor takes and std::shared_ptr. We expect + # that only one instance of the C++ PySolveInterrupter will exist here. + interrupter = pybind_solve_interrupter.PySolveInterrupter() + interrupter_ref = pybind_solve_interrupter_testing.PySolveInterrupterReference( + interrupter + ) + + # Validate that we have the expected number of references of the shared_ptr + # hold by PySolveInterrupterReference. + self.assertEqual(interrupter_ref.use_count, 2) + self.assertFalse(interrupter_ref.is_interrupted) + + # Triggering the interrupter should be visible the pointed + # PySolveInterrupter in C++. + interrupter.interrupt() + self.assertEqual(interrupter_ref.use_count, 2) + self.assertTrue(interrupter_ref.is_interrupted) + + # Removing the Python `interrupter` reference should make `interrupter_ref` + # the only object that holds an std::shared_ptr on the interrupter. + del interrupter + self.assertEqual(interrupter_ref.use_count, 1) + self.assertTrue(interrupter_ref.is_interrupted) + + def test_add_target(self) -> None: + source = pybind_solve_interrupter.PySolveInterrupter() + target = pybind_solve_interrupter.PySolveInterrupter() + source.add_trigger_target(target) + source.interrupt() + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source)) + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(target)) + + def test_remove_existing_target(self) -> None: + source = pybind_solve_interrupter.PySolveInterrupter() + target = pybind_solve_interrupter.PySolveInterrupter() + source.add_trigger_target(target) + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source)) + self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target)) + + def test_remove_existing_target_added_twice(self) -> None: + source = pybind_solve_interrupter.PySolveInterrupter() + target = pybind_solve_interrupter.PySolveInterrupter() + source.add_trigger_target(target) + source.add_trigger_target(target) + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source)) + self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target)) + + def test_dead_target(self) -> None: + source = pybind_solve_interrupter.PySolveInterrupter() + target = pybind_solve_interrupter.PySolveInterrupter() + source.add_trigger_target(target) + del target + source.interrupt() + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source)) + + def test_remove_existing_target_twice(self) -> None: + source = pybind_solve_interrupter.PySolveInterrupter() + target = pybind_solve_interrupter.PySolveInterrupter() + source.add_trigger_target(target) + source.remove_trigger_target(target) + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source)) + self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target)) + + def test_remove_non_target(self) -> None: + source = pybind_solve_interrupter.PySolveInterrupter() + target = pybind_solve_interrupter.PySolveInterrupter() + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source)) + self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target)) + + +if __name__ == "__main__": + absltest.main() diff --git a/ortools/util/python/pybind_solve_interrupter_testing.cc b/ortools/util/python/pybind_solve_interrupter_testing.cc new file mode 100644 index 0000000000..3fe1f10e0e --- /dev/null +++ b/ortools/util/python/pybind_solve_interrupter_testing.cc @@ -0,0 +1,34 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "ortools/util/python/py_solve_interrupter_testing.h" + +namespace operations_research { + +namespace py = ::pybind11; + +PYBIND11_MODULE(pybind_solve_interrupter_testing, m) { + py::class_(m, "PySolveInterrupterReference") + .def(py::init>()) + .def_property_readonly("use_count", + &PySolveInterrupterReference::use_count) + .def_property_readonly("is_interrupted", + &PySolveInterrupterReference::is_interrupted); + m.def("IsInterrupted", &IsInterrupted, py::arg("interrupter")); +} + +} // namespace operations_research diff --git a/ortools/util/python/solve_interrupter.py b/ortools/util/python/solve_interrupter.py new file mode 100644 index 0000000000..19c629947c --- /dev/null +++ b/ortools/util/python/solve_interrupter.py @@ -0,0 +1,136 @@ +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python interrupter for solves.""" + +from collections.abc import Callable, Iterator +import contextlib +from typing import Optional + +from absl import logging + +from ortools.util.python import pybind_solve_interrupter + + +class CallbackError(Exception): + """Exception raised when an interrupter callback fails. + + When using SolveInterrupter.interruption_callback(), this exception is raised + when exiting the context manager if the callback failed. The error in the + callback is the cause of this exception. + """ + + +class SolveInterrupter: + """Interrupter used by solvers to know when they should interrupt the solve. + + Once triggered with interrupt(), an interrupter can't be reset. It can be + triggered from any thread. + + Thread-safety: APIs on this class are safe to call concurrently from multiple + threads. + + Attributes: + pybind_interrupter: The pybind wrapper around PySolveInterrupter. + """ + + def __init__(self) -> None: + self.pybind_interrupter = pybind_solve_interrupter.PySolveInterrupter() + + def interrupt(self) -> None: + """Interrupts the solve as soon as possible. + + Once requested the interruption can't be reset. The user should use a new + SolveInterrupter for later solves. + + It is safe to call this function multiple times. Only the first call will + have visible effects; other calls will be ignored. + """ + self.pybind_interrupter.interrupt() + + @property + def interrupted(self) -> bool: + """True if the solve interruption has been requested.""" + return self.pybind_interrupter.interrupted + + def add_trigger_target(self, target: "SolveInterrupter") -> None: + """Triggers the target when this interrupter is triggered.""" + self.pybind_interrupter.add_trigger_target(target.pybind_interrupter) + + def remove_trigger_target(self, target: "SolveInterrupter") -> None: + """Removes the target if not null and present, else do nothing.""" + self.pybind_interrupter.remove_trigger_target(target.pybind_interrupter) + + @contextlib.contextmanager + def interruption_callback(self, callback: Callable[[], None]) -> Iterator[None]: + """Returns a context manager that (un)register the provided callback. + + The callback is immediately called if the interrupter has already been + triggered or if it is triggered during the registration. This is typically + useful for a solver implementation so that it does not have to test + `interrupted` to do the same thing it does in the callback. Simply + registering the callback is enough. + + The callback function can't make calls to interruption_callback(), and + interrupt(). This would result is a deadlock. Reading `interrupted` is fine + though. + + Exceptions raised in the callback are raised on exit from the context + manager if no other error happens within the context. Else the exception is + logged. + + Args: + callback: The callback. + + Returns: + A context manager. + + Raises: + CallbackError: When exiting the context manager if an exception was raised + in the callback. + """ + callback_error: Optional[Exception] = None + + def protetected_callback(): + """Calls callback() storing any exception in callback_error.""" + nonlocal callback_error + try: + callback() + except Exception as e: # pylint: disable=broad-exception-caught + # It is fine to set callback_error without any threading protection as + # the SolveInterrupter guarantees it will only call it at most once. + callback_error = e + + callback_id = self.pybind_interrupter.add_interruption_callback( + protetected_callback + ) + + no_exception_in_context = False + try: + yield + no_exception_in_context = True + finally: + self.pybind_interrupter.remove_interruption_callback(callback_id) + # It is fine to access callback_error without threading protection after + # remove_interruption_callback() has returned as the SolveInterrupter + # guarantees any pending call is done and no future call can happen. + if callback_error is not None: + if no_exception_in_context: + raise CallbackError() from callback_error + # We don't want the error in the context to be masked by an error in the + # callback. We log it instead. + logging.error( + "An exception occurred in callback but is masked by another" + " exception: %s", + repr(callback_error), + ) diff --git a/ortools/util/python/solve_interrupter_test.py b/ortools/util/python/solve_interrupter_test.py new file mode 100644 index 0000000000..d9008b8de7 --- /dev/null +++ b/ortools/util/python/solve_interrupter_test.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from ortools.util.python import pybind_solve_interrupter_testing +from ortools.util.python import solve_interrupter + + +class SolveInterrupterTest(absltest.TestCase): + + def test_untriggered_interrupter(self) -> None: + interrupter = solve_interrupter.SolveInterrupter() + self.assertFalse(interrupter.interrupted) + + def test_triggered_interrupter(self) -> None: + interrupter = solve_interrupter.SolveInterrupter() + interrupter.interrupt() + self.assertTrue(interrupter.interrupted) + + def test_add_target(self) -> None: + source = solve_interrupter.SolveInterrupter() + target = solve_interrupter.SolveInterrupter() + source.add_trigger_target(target) + source.interrupt() + self.assertTrue(source.interrupted) + self.assertTrue(target.interrupted) + + def test_remove_existing_target(self) -> None: + source = solve_interrupter.SolveInterrupter() + target = solve_interrupter.SolveInterrupter() + source.add_trigger_target(target) + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(source.interrupted) + self.assertFalse(target.interrupted) + + def test_remove_existing_target_added_twice(self) -> None: + source = solve_interrupter.SolveInterrupter() + target = solve_interrupter.SolveInterrupter() + source.add_trigger_target(target) + source.add_trigger_target(target) + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(source.interrupted) + self.assertFalse(target.interrupted) + + def test_dead_target(self) -> None: + source = solve_interrupter.SolveInterrupter() + target = solve_interrupter.SolveInterrupter() + source.add_trigger_target(target) + del target + source.interrupt() + self.assertTrue(source.interrupted) + + def test_remove_existing_target_twice(self) -> None: + source = solve_interrupter.SolveInterrupter() + target = solve_interrupter.SolveInterrupter() + source.add_trigger_target(target) + source.remove_trigger_target(target) + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(source.interrupted) + self.assertFalse(target.interrupted) + + def test_remove_non_target(self) -> None: + source = solve_interrupter.SolveInterrupter() + target = solve_interrupter.SolveInterrupter() + source.remove_trigger_target(target) + source.interrupt() + self.assertTrue(source.interrupted) + self.assertFalse(target.interrupted) + + def test_callback_already_interrupted(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + + interrupter = solve_interrupter.SolveInterrupter() + interrupter.interrupt() + + with interrupter.interruption_callback(callback): + self.assertEqual(num_calls, 1) + + self.assertEqual(num_calls, 1) + + def test_callback_interruption(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + + interrupter = solve_interrupter.SolveInterrupter() + + with interrupter.interruption_callback(callback): + self.assertEqual(num_calls, 0) + interrupter.interrupt() + self.assertEqual(num_calls, 1) + + self.assertEqual(num_calls, 1) + + def test_callback_interruption_after_removal(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + + interrupter = solve_interrupter.SolveInterrupter() + + with interrupter.interruption_callback(callback): + self.assertEqual(num_calls, 0) + + interrupter.interrupt() + self.assertEqual(num_calls, 0) + + def test_callback_nointerruption(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + + interrupter = solve_interrupter.SolveInterrupter() + + with interrupter.interruption_callback(callback): + self.assertEqual(num_calls, 0) + + self.assertEqual(num_calls, 0) + + def test_callback_with_exception_in_callback(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + raise ValueError("error-in-callback") + + interrupter = solve_interrupter.SolveInterrupter() + + # has_finished is set to true after the call to interrupter.interrupt() and + # will be used to validate that interrupt() does not raise the exception, + # only the __exit__() of interruption_callback() context does. + has_finished = False + with self.assertRaises(solve_interrupter.CallbackError) as cm: + with interrupter.interruption_callback(callback): + before_interrupt_num_calls = num_calls + interrupter.interrupt() + after_interrupt_num_calls = num_calls + has_finished = True + # Test the cause of the CallbackError, which should be the original error. + self.assertIsInstance(cm.exception.__cause__, ValueError) + self.assertEqual(str(cm.exception.__cause__), "error-in-callback") + + self.assertEqual(before_interrupt_num_calls, 0) + self.assertEqual(after_interrupt_num_calls, 1) + self.assertTrue(has_finished) + self.assertEqual(num_calls, 1) + + def test_callback_with_exception_in_context(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + + interrupter = solve_interrupter.SolveInterrupter() + + with self.assertRaisesRegex(ValueError, "error-in-context"): + with interrupter.interruption_callback(callback): + raise ValueError("error-in-context") + + interrupter.interrupt() + self.assertEqual(num_calls, 0) + + def test_callback_with_exception_in_callback_and_context(self) -> None: + num_calls = 0 + + def callback(): + nonlocal num_calls + num_calls += 1 + raise ValueError("error-in-callback") + + interrupter = solve_interrupter.SolveInterrupter() + + with self.assertRaisesRegex(ValueError, "error-in-context"): + with interrupter.interruption_callback(callback): + before_interrupt_num_calls = num_calls + interrupter.interrupt() + after_interrupt_num_calls = num_calls + raise ValueError("error-in-context") + + self.assertEqual(before_interrupt_num_calls, 0) + self.assertEqual(after_interrupt_num_calls, 1) + self.assertEqual(num_calls, 1) + + def test_pybind_interrupter(self) -> None: + interrupter = solve_interrupter.SolveInterrupter() + pybind_interrupter = interrupter.pybind_interrupter + + self.assertFalse( + pybind_solve_interrupter_testing.IsInterrupted(pybind_interrupter) + ) + + interrupter.interrupt() + + self.assertTrue( + pybind_solve_interrupter_testing.IsInterrupted(pybind_interrupter) + ) + + +if __name__ == "__main__": + absltest.main()