util/python: add solve interrupter

This commit is contained in:
Corentin Le Molgat
2025-08-20 11:25:07 +02:00
parent fcf4bd181e
commit de5fbc46ab
10 changed files with 955 additions and 0 deletions

View File

@@ -16,6 +16,7 @@
load("@pip_deps//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_python//python:py_library.bzl", "py_library")
load("@rules_python//python:py_test.bzl", "py_test")
cc_library(
@@ -44,3 +45,83 @@ py_test(
requirement("absl-py"),
],
)
cc_library(
name = "py_solve_interrupter_lib",
srcs = ["py_solve_interrupter.cc"],
hdrs = ["py_solve_interrupter.h"],
# This library is not meant to be consumed end users; only by C++ code of
# Clif libraries that needs a SolverInterrupter.
visibility = [
"//ortools/math_opt/core/python:__subpackages__",
"//ortools/math_opt/python:__subpackages__",
],
deps = [
"//ortools/util:solve_interrupter",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/base:nullability",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/synchronization",
],
)
cc_library(
name = "py_solve_interrupter_testing_lib",
testonly = True,
srcs = ["py_solve_interrupter_testing.cc"],
hdrs = ["py_solve_interrupter_testing.h"],
deps = [
":py_solve_interrupter_lib",
"@abseil-cpp//absl/base:nullability",
],
)
pybind_extension(
name = "pybind_solve_interrupter",
srcs = ["pybind_solve_interrupter.cc"],
# This library is not meant to be consumed end users; only pybind11 code
# that needs a SolverInterrupter.
visibility = [
"//ortools/math_opt/core/python:__subpackages__",
],
deps = [":py_solve_interrupter_lib"],
)
pybind_extension(
name = "pybind_solve_interrupter_testing",
testonly = True,
srcs = ["pybind_solve_interrupter_testing.cc"],
deps = [":py_solve_interrupter_testing_lib"],
)
py_test(
name = "pybind_solve_interrupter_test",
srcs = ["pybind_solve_interrupter_test.py"],
deps = [
":pybind_solve_interrupter",
":pybind_solve_interrupter_testing",
requirement("absl-py"),
],
)
py_library(
name = "solve_interrupter",
srcs = ["solve_interrupter.py"],
visibility = ["//visibility:public"],
deps = [
":py_solve_interrupter_lib",
":pybind_solve_interrupter",
requirement("absl-py"),
],
)
py_test(
name = "solve_interrupter_test",
srcs = ["solve_interrupter_test.py"],
deps = [
":py_solve_interrupter_testing_lib",
":pybind_solve_interrupter_testing",
":solve_interrupter",
requirement("absl-py"),
],
)

View File

@@ -0,0 +1,91 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/util/python/py_solve_interrupter.h"
#include <memory>
#include <utility>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/log/log.h"
#include "absl/synchronization/mutex.h"
namespace operations_research {
PySolveInterrupter::PySolveInterrupter()
: callback_(&interrupter_, [this]() { TriggerTargets(); }) {}
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>>
PySolveInterrupter::CleanupAndGetTargets(
const PySolveInterrupter* absl_nullable const to_remove) {
// First get strong references of non-expired targets. Also filter the
// to_remove target if found.
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>>
non_expired_targets;
non_expired_targets.reserve(targets_.size());
for (std::weak_ptr<PySolveInterrupter>& weak_target : targets_) {
std::shared_ptr<PySolveInterrupter> strong_target = weak_target.lock();
if (strong_target != nullptr && strong_target.get() != to_remove) {
non_expired_targets.push_back(std::move(strong_target));
}
}
// Then recreates targets weak-references with only the non-expired targets.
//
// Note that we could be more efficient by doing the cleanup in-place but here
// we keep the code simple. In practice the number of targets_ is expected to
// be very low (less than 10).
targets_.clear();
for (std::shared_ptr<PySolveInterrupter>& strong_target :
non_expired_targets) {
targets_.push_back(strong_target);
}
return non_expired_targets;
}
void PySolveInterrupter::AddTriggerTarget(
absl_nonnull std::shared_ptr<PySolveInterrupter> target) {
const absl::MutexLock lock(&mutex_);
CleanupAndGetTargets();
// Note that we don't test if targets_ already contain the interrupter as
// interrupter are triggerable only once. Thus having duplicates won't result
// in visible changes outside.
//
// The way RemoveTriggerTarget() is implemented will remove duplicates as
// well.
targets_.push_back(std::move(target));
}
void PySolveInterrupter::RemoveTriggerTarget(
absl_nonnull std::shared_ptr<PySolveInterrupter> target) {
const absl::MutexLock lock(&mutex_);
CleanupAndGetTargets(/*to_remove=*/target.get());
}
void PySolveInterrupter::TriggerTargets() {
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>> targets;
{ // Limit `lock` scope.
const absl::MutexLock lock(&mutex_);
targets = CleanupAndGetTargets();
}
// Call targets without holding mutex_.
for (const absl_nonnull std::shared_ptr<PySolveInterrupter>& target :
targets) {
target->Interrupt();
}
}
} // namespace operations_research

View File

@@ -0,0 +1,116 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_
#define OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_
#include <memory>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "ortools/util/solve_interrupter.h"
namespace operations_research {
// Simple class that wraps a SolveInterrupter for Clif or pybind11.
//
class PySolveInterrupter {
public:
PySolveInterrupter();
PySolveInterrupter(const PySolveInterrupter&) = delete;
PySolveInterrupter& operator=(const PySolveInterrupter&) = delete;
// Interrupts the solve as soon as possible.
void Interrupt() { interrupter_.Interrupt(); }
// Returns true if the solve interruption has been requested.
inline bool IsInterrupted() const { return interrupter_.IsInterrupted(); }
// Triggers the target when this interrupter is triggered.
//
// A std::weak_ptr is kept on the target. Expired std::weak_ptr are cleaned up
// on calls to AddTriggerTarget() and RemoveTriggerTarget().
//
// Complexity: O(num_targets).
void AddTriggerTarget(absl_nonnull std::shared_ptr<PySolveInterrupter> target)
ABSL_LOCKS_EXCLUDED(mutex_);
// Removes the target if not null and present, else do nothing.
//
// Complexity: O(num_targets).
void RemoveTriggerTarget(
absl_nonnull std::shared_ptr<PySolveInterrupter> target)
ABSL_LOCKS_EXCLUDED(mutex_);
// Add a callback on the interrupter and returns an id to use to remove it.
//
// See SolveInterrupter::AddInterruptionCallback().
int64_t AddInterruptionCallback(std::function<void()> callback) const {
return interrupter_.AddInterruptionCallback(std::move(callback)).value();
}
// Remove a callback previously registered by AddInterruptionCallback().
//
// See SolveInterrupter::RemoveInterruptionCallback().
void RemoveInterruptionCallback(int64_t callback_id) const {
interrupter_.RemoveInterruptionCallback(
SolveInterrupter::CallbackId{callback_id});
}
// Return the underlying interrupter. This method is not exposed in Python and
// only available for C++ code.
//
// The lifetime of the PySolveInterrupter is controlled by Python. To prevent
// issues where the underlying SolveInterrupter would be destroyed while still
// being pointed to by C++ code, C++ Clif/pybind11 consumer code should take
// references to PySolveInterrupter in an std::shared_ptr that outlives any
// reference.
const SolveInterrupter* absl_nonnull interrupter() const {
return &interrupter_;
}
private:
// Remove expired targets_, the optional to_remove target and return strong
// references on non-expired targets.
//
// The caller must have acquired mutex_.
//
// Complexity: O(num_targets).
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>>
CleanupAndGetTargets(const PySolveInterrupter* absl_nullable to_remove =
nullptr) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Triggers all non-expired targets_ interrupters.
void TriggerTargets() ABSL_LOCKS_EXCLUDED(mutex_);
SolveInterrupter interrupter_;
absl::Mutex mutex_;
// Interrupters to trigger when interrupter_ is triggered.
std::vector<std::weak_ptr<PySolveInterrupter>> targets_
ABSL_GUARDED_BY(mutex_);
// Callback that will trigger all interrupters in target_.
//
// It MUST appear after targets_ and interrupter_ as we want to make sure it
// is destroyed first!
ScopedSolveInterrupterCallback callback_;
};
} // namespace operations_research
#endif // OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_

View File

@@ -0,0 +1,29 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/util/python/py_solve_interrupter_testing.h"
#include <optional>
#include "ortools/util/python/py_solve_interrupter.h"
namespace operations_research {
std::optional<bool> IsInterrupted(const PySolveInterrupter* interrupter) {
if (interrupter == nullptr) {
return std::nullopt;
}
return interrupter->IsInterrupted();
}
} // namespace operations_research

View File

@@ -0,0 +1,62 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Library to unit test PySolveInterrupter wrapper.
#ifndef OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_
#define OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_
#include <memory>
#include <optional>
#include <utility>
#include "absl/base/nullability.h"
#include "ortools/util/python/py_solve_interrupter.h"
namespace operations_research {
// Returns:
// * nullopt if `interrupter` is nullptr,
// * or false if `interrupter` is not nullptr and is not interrupted,
// * or true if `interrupted` is not nullptr and is interrupted.
//
// The Clif/pybind11 wrapper will return a `bool | None` value, with None for
// nullopt.
std::optional<bool> IsInterrupted(const PySolveInterrupter* interrupter);
// Class that keeps a reference on a std::shared_ptr<PySolveInterrupter> to
// test that the C++ object survive the cleanup of the Python reference.
class PySolveInterrupterReference {
public:
explicit PySolveInterrupterReference(
absl_nonnull std::shared_ptr<PySolveInterrupter> interrupter)
: interrupter_(std::move(interrupter)) {}
PySolveInterrupterReference(const PySolveInterrupterReference&) = delete;
PySolveInterrupterReference& operator=(const PySolveInterrupterReference&) =
delete;
// Returns the std::shared_ptr<PySolveInterrupter>::use_count().
//
// This is used to test that Python has stopped pointing to the object.
int use_count() const { return interrupter_.use_count(); }
// Returns true if the underlying interrupter is interrupted.
bool is_interrupted() const { return interrupter_->IsInterrupted(); }
private:
const absl_nonnull std::shared_ptr<PySolveInterrupter> interrupter_;
};
} // namespace operations_research
#endif // OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_

View File

@@ -0,0 +1,41 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ortools/util/python/py_solve_interrupter.h"
namespace operations_research {
namespace py = ::pybind11;
PYBIND11_MODULE(pybind_solve_interrupter, m) {
py::class_<PySolveInterrupter, std::shared_ptr<PySolveInterrupter>>(
m, "PySolveInterrupter")
.def(py::init())
.def("interrupt", &PySolveInterrupter::Interrupt)
.def_property_readonly("interrupted", &PySolveInterrupter::IsInterrupted)
.def("add_trigger_target", &PySolveInterrupter::AddTriggerTarget,
py::arg("target"))
.def("remove_trigger_target", &PySolveInterrupter::RemoveTriggerTarget,
py::arg("target"))
.def("add_interruption_callback",
&PySolveInterrupter::AddInterruptionCallback, py::arg("callback"))
.def("remove_interruption_callback",
&PySolveInterrupter::RemoveInterruptionCallback,
py::arg("callback_id"));
}
} // namespace operations_research

View File

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
# Copyright 2010-2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from ortools.util.python import pybind_solve_interrupter
from ortools.util.python import pybind_solve_interrupter_testing
class PybindPySolveInterrupterTest(absltest.TestCase):
"""Test pybind_solve_interrupter.PySolveInterrupter class.
This tests how pybind_solve_interrupter.PySolveInterrupter is accepted by
other
pybind-generated APIs in Python, using pybind_solve_interrupter_testing to do
so.
"""
def test_none_interrupter(self) -> None:
"""Test that a `const PySolveInterrupter*` can receive nullptr."""
# Here IsInterrupted() is a Clif wrapped C++ function that expects a `const
# PySolveInterrupter*`. It returns None (std::nullopt) for nullptr
# input.
self.assertIsNone(pybind_solve_interrupter_testing.IsInterrupted(None))
def test_untriggered_interrupter(self) -> None:
"""Test that an untriggered interrupter is properly passed."""
interrupter = pybind_solve_interrupter.PySolveInterrupter()
# Test the getter.
self.assertFalse(interrupter.interrupted)
# Test using a Clif wrapped C++ function. We test it is not None to make
# sure Clif passes a non-null pointer to the C++ code. We thus have to use
# assertIs() instead of assertFalse().
self.assertIs(
pybind_solve_interrupter_testing.IsInterrupted(interrupter), False
)
def test_triggered_interrupter(self) -> None:
"""Test that an untriggered interrupter is properly passed."""
interrupter = pybind_solve_interrupter.PySolveInterrupter()
interrupter.interrupt()
# Test the getter.
self.assertTrue(interrupter.interrupted)
# Test using a Clif wrapped C++ function.
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(interrupter))
def test_shared_reference(self) -> None:
"""Test that taking a std::shared_ptr<PySolveInterrupter> works."""
# Create an interrupter and a PySolveInterrupterReference class which
# constructor takes and std::shared_ptr<PySolveInterrupter>. We expect
# that only one instance of the C++ PySolveInterrupter will exist here.
interrupter = pybind_solve_interrupter.PySolveInterrupter()
interrupter_ref = pybind_solve_interrupter_testing.PySolveInterrupterReference(
interrupter
)
# Validate that we have the expected number of references of the shared_ptr
# hold by PySolveInterrupterReference.
self.assertEqual(interrupter_ref.use_count, 2)
self.assertFalse(interrupter_ref.is_interrupted)
# Triggering the interrupter should be visible the pointed
# PySolveInterrupter in C++.
interrupter.interrupt()
self.assertEqual(interrupter_ref.use_count, 2)
self.assertTrue(interrupter_ref.is_interrupted)
# Removing the Python `interrupter` reference should make `interrupter_ref`
# the only object that holds an std::shared_ptr on the interrupter.
del interrupter
self.assertEqual(interrupter_ref.use_count, 1)
self.assertTrue(interrupter_ref.is_interrupted)
def test_add_target(self) -> None:
source = pybind_solve_interrupter.PySolveInterrupter()
target = pybind_solve_interrupter.PySolveInterrupter()
source.add_trigger_target(target)
source.interrupt()
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(target))
def test_remove_existing_target(self) -> None:
source = pybind_solve_interrupter.PySolveInterrupter()
target = pybind_solve_interrupter.PySolveInterrupter()
source.add_trigger_target(target)
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
def test_remove_existing_target_added_twice(self) -> None:
source = pybind_solve_interrupter.PySolveInterrupter()
target = pybind_solve_interrupter.PySolveInterrupter()
source.add_trigger_target(target)
source.add_trigger_target(target)
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
def test_dead_target(self) -> None:
source = pybind_solve_interrupter.PySolveInterrupter()
target = pybind_solve_interrupter.PySolveInterrupter()
source.add_trigger_target(target)
del target
source.interrupt()
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
def test_remove_existing_target_twice(self) -> None:
source = pybind_solve_interrupter.PySolveInterrupter()
target = pybind_solve_interrupter.PySolveInterrupter()
source.add_trigger_target(target)
source.remove_trigger_target(target)
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
def test_remove_non_target(self) -> None:
source = pybind_solve_interrupter.PySolveInterrupter()
target = pybind_solve_interrupter.PySolveInterrupter()
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
if __name__ == "__main__":
absltest.main()

View File

@@ -0,0 +1,34 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ortools/util/python/py_solve_interrupter_testing.h"
namespace operations_research {
namespace py = ::pybind11;
PYBIND11_MODULE(pybind_solve_interrupter_testing, m) {
py::class_<PySolveInterrupterReference>(m, "PySolveInterrupterReference")
.def(py::init<std::shared_ptr<PySolveInterrupter>>())
.def_property_readonly("use_count",
&PySolveInterrupterReference::use_count)
.def_property_readonly("is_interrupted",
&PySolveInterrupterReference::is_interrupted);
m.def("IsInterrupted", &IsInterrupted, py::arg("interrupter"));
}
} // namespace operations_research

View File

@@ -0,0 +1,136 @@
# Copyright 2010-2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python interrupter for solves."""
from collections.abc import Callable, Iterator
import contextlib
from typing import Optional
from absl import logging
from ortools.util.python import pybind_solve_interrupter
class CallbackError(Exception):
"""Exception raised when an interrupter callback fails.
When using SolveInterrupter.interruption_callback(), this exception is raised
when exiting the context manager if the callback failed. The error in the
callback is the cause of this exception.
"""
class SolveInterrupter:
"""Interrupter used by solvers to know when they should interrupt the solve.
Once triggered with interrupt(), an interrupter can't be reset. It can be
triggered from any thread.
Thread-safety: APIs on this class are safe to call concurrently from multiple
threads.
Attributes:
pybind_interrupter: The pybind wrapper around PySolveInterrupter.
"""
def __init__(self) -> None:
self.pybind_interrupter = pybind_solve_interrupter.PySolveInterrupter()
def interrupt(self) -> None:
"""Interrupts the solve as soon as possible.
Once requested the interruption can't be reset. The user should use a new
SolveInterrupter for later solves.
It is safe to call this function multiple times. Only the first call will
have visible effects; other calls will be ignored.
"""
self.pybind_interrupter.interrupt()
@property
def interrupted(self) -> bool:
"""True if the solve interruption has been requested."""
return self.pybind_interrupter.interrupted
def add_trigger_target(self, target: "SolveInterrupter") -> None:
"""Triggers the target when this interrupter is triggered."""
self.pybind_interrupter.add_trigger_target(target.pybind_interrupter)
def remove_trigger_target(self, target: "SolveInterrupter") -> None:
"""Removes the target if not null and present, else do nothing."""
self.pybind_interrupter.remove_trigger_target(target.pybind_interrupter)
@contextlib.contextmanager
def interruption_callback(self, callback: Callable[[], None]) -> Iterator[None]:
"""Returns a context manager that (un)register the provided callback.
The callback is immediately called if the interrupter has already been
triggered or if it is triggered during the registration. This is typically
useful for a solver implementation so that it does not have to test
`interrupted` to do the same thing it does in the callback. Simply
registering the callback is enough.
The callback function can't make calls to interruption_callback(), and
interrupt(). This would result is a deadlock. Reading `interrupted` is fine
though.
Exceptions raised in the callback are raised on exit from the context
manager if no other error happens within the context. Else the exception is
logged.
Args:
callback: The callback.
Returns:
A context manager.
Raises:
CallbackError: When exiting the context manager if an exception was raised
in the callback.
"""
callback_error: Optional[Exception] = None
def protetected_callback():
"""Calls callback() storing any exception in callback_error."""
nonlocal callback_error
try:
callback()
except Exception as e: # pylint: disable=broad-exception-caught
# It is fine to set callback_error without any threading protection as
# the SolveInterrupter guarantees it will only call it at most once.
callback_error = e
callback_id = self.pybind_interrupter.add_interruption_callback(
protetected_callback
)
no_exception_in_context = False
try:
yield
no_exception_in_context = True
finally:
self.pybind_interrupter.remove_interruption_callback(callback_id)
# It is fine to access callback_error without threading protection after
# remove_interruption_callback() has returned as the SolveInterrupter
# guarantees any pending call is done and no future call can happen.
if callback_error is not None:
if no_exception_in_context:
raise CallbackError() from callback_error
# We don't want the error in the context to be masked by an error in the
# callback. We log it instead.
logging.error(
"An exception occurred in callback but is masked by another"
" exception: %s",
repr(callback_error),
)

View File

@@ -0,0 +1,226 @@
#!/usr/bin/env python3
# Copyright 2010-2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from ortools.util.python import pybind_solve_interrupter_testing
from ortools.util.python import solve_interrupter
class SolveInterrupterTest(absltest.TestCase):
def test_untriggered_interrupter(self) -> None:
interrupter = solve_interrupter.SolveInterrupter()
self.assertFalse(interrupter.interrupted)
def test_triggered_interrupter(self) -> None:
interrupter = solve_interrupter.SolveInterrupter()
interrupter.interrupt()
self.assertTrue(interrupter.interrupted)
def test_add_target(self) -> None:
source = solve_interrupter.SolveInterrupter()
target = solve_interrupter.SolveInterrupter()
source.add_trigger_target(target)
source.interrupt()
self.assertTrue(source.interrupted)
self.assertTrue(target.interrupted)
def test_remove_existing_target(self) -> None:
source = solve_interrupter.SolveInterrupter()
target = solve_interrupter.SolveInterrupter()
source.add_trigger_target(target)
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(source.interrupted)
self.assertFalse(target.interrupted)
def test_remove_existing_target_added_twice(self) -> None:
source = solve_interrupter.SolveInterrupter()
target = solve_interrupter.SolveInterrupter()
source.add_trigger_target(target)
source.add_trigger_target(target)
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(source.interrupted)
self.assertFalse(target.interrupted)
def test_dead_target(self) -> None:
source = solve_interrupter.SolveInterrupter()
target = solve_interrupter.SolveInterrupter()
source.add_trigger_target(target)
del target
source.interrupt()
self.assertTrue(source.interrupted)
def test_remove_existing_target_twice(self) -> None:
source = solve_interrupter.SolveInterrupter()
target = solve_interrupter.SolveInterrupter()
source.add_trigger_target(target)
source.remove_trigger_target(target)
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(source.interrupted)
self.assertFalse(target.interrupted)
def test_remove_non_target(self) -> None:
source = solve_interrupter.SolveInterrupter()
target = solve_interrupter.SolveInterrupter()
source.remove_trigger_target(target)
source.interrupt()
self.assertTrue(source.interrupted)
self.assertFalse(target.interrupted)
def test_callback_already_interrupted(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
interrupter = solve_interrupter.SolveInterrupter()
interrupter.interrupt()
with interrupter.interruption_callback(callback):
self.assertEqual(num_calls, 1)
self.assertEqual(num_calls, 1)
def test_callback_interruption(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
interrupter = solve_interrupter.SolveInterrupter()
with interrupter.interruption_callback(callback):
self.assertEqual(num_calls, 0)
interrupter.interrupt()
self.assertEqual(num_calls, 1)
self.assertEqual(num_calls, 1)
def test_callback_interruption_after_removal(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
interrupter = solve_interrupter.SolveInterrupter()
with interrupter.interruption_callback(callback):
self.assertEqual(num_calls, 0)
interrupter.interrupt()
self.assertEqual(num_calls, 0)
def test_callback_nointerruption(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
interrupter = solve_interrupter.SolveInterrupter()
with interrupter.interruption_callback(callback):
self.assertEqual(num_calls, 0)
self.assertEqual(num_calls, 0)
def test_callback_with_exception_in_callback(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
raise ValueError("error-in-callback")
interrupter = solve_interrupter.SolveInterrupter()
# has_finished is set to true after the call to interrupter.interrupt() and
# will be used to validate that interrupt() does not raise the exception,
# only the __exit__() of interruption_callback() context does.
has_finished = False
with self.assertRaises(solve_interrupter.CallbackError) as cm:
with interrupter.interruption_callback(callback):
before_interrupt_num_calls = num_calls
interrupter.interrupt()
after_interrupt_num_calls = num_calls
has_finished = True
# Test the cause of the CallbackError, which should be the original error.
self.assertIsInstance(cm.exception.__cause__, ValueError)
self.assertEqual(str(cm.exception.__cause__), "error-in-callback")
self.assertEqual(before_interrupt_num_calls, 0)
self.assertEqual(after_interrupt_num_calls, 1)
self.assertTrue(has_finished)
self.assertEqual(num_calls, 1)
def test_callback_with_exception_in_context(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
interrupter = solve_interrupter.SolveInterrupter()
with self.assertRaisesRegex(ValueError, "error-in-context"):
with interrupter.interruption_callback(callback):
raise ValueError("error-in-context")
interrupter.interrupt()
self.assertEqual(num_calls, 0)
def test_callback_with_exception_in_callback_and_context(self) -> None:
num_calls = 0
def callback():
nonlocal num_calls
num_calls += 1
raise ValueError("error-in-callback")
interrupter = solve_interrupter.SolveInterrupter()
with self.assertRaisesRegex(ValueError, "error-in-context"):
with interrupter.interruption_callback(callback):
before_interrupt_num_calls = num_calls
interrupter.interrupt()
after_interrupt_num_calls = num_calls
raise ValueError("error-in-context")
self.assertEqual(before_interrupt_num_calls, 0)
self.assertEqual(after_interrupt_num_calls, 1)
self.assertEqual(num_calls, 1)
def test_pybind_interrupter(self) -> None:
interrupter = solve_interrupter.SolveInterrupter()
pybind_interrupter = interrupter.pybind_interrupter
self.assertFalse(
pybind_solve_interrupter_testing.IsInterrupted(pybind_interrupter)
)
interrupter.interrupt()
self.assertTrue(
pybind_solve_interrupter_testing.IsInterrupted(pybind_interrupter)
)
if __name__ == "__main__":
absltest.main()