util/python: add solve interrupter
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
load("@pip_deps//:requirements.bzl", "requirement")
|
||||
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
|
||||
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
||||
load("@rules_python//python:py_library.bzl", "py_library")
|
||||
load("@rules_python//python:py_test.bzl", "py_test")
|
||||
|
||||
cc_library(
|
||||
@@ -44,3 +45,83 @@ py_test(
|
||||
requirement("absl-py"),
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_solve_interrupter_lib",
|
||||
srcs = ["py_solve_interrupter.cc"],
|
||||
hdrs = ["py_solve_interrupter.h"],
|
||||
# This library is not meant to be consumed end users; only by C++ code of
|
||||
# Clif libraries that needs a SolverInterrupter.
|
||||
visibility = [
|
||||
"//ortools/math_opt/core/python:__subpackages__",
|
||||
"//ortools/math_opt/python:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//ortools/util:solve_interrupter",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
"@abseil-cpp//absl/base:nullability",
|
||||
"@abseil-cpp//absl/log",
|
||||
"@abseil-cpp//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_solve_interrupter_testing_lib",
|
||||
testonly = True,
|
||||
srcs = ["py_solve_interrupter_testing.cc"],
|
||||
hdrs = ["py_solve_interrupter_testing.h"],
|
||||
deps = [
|
||||
":py_solve_interrupter_lib",
|
||||
"@abseil-cpp//absl/base:nullability",
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "pybind_solve_interrupter",
|
||||
srcs = ["pybind_solve_interrupter.cc"],
|
||||
# This library is not meant to be consumed end users; only pybind11 code
|
||||
# that needs a SolverInterrupter.
|
||||
visibility = [
|
||||
"//ortools/math_opt/core/python:__subpackages__",
|
||||
],
|
||||
deps = [":py_solve_interrupter_lib"],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "pybind_solve_interrupter_testing",
|
||||
testonly = True,
|
||||
srcs = ["pybind_solve_interrupter_testing.cc"],
|
||||
deps = [":py_solve_interrupter_testing_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "pybind_solve_interrupter_test",
|
||||
srcs = ["pybind_solve_interrupter_test.py"],
|
||||
deps = [
|
||||
":pybind_solve_interrupter",
|
||||
":pybind_solve_interrupter_testing",
|
||||
requirement("absl-py"),
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "solve_interrupter",
|
||||
srcs = ["solve_interrupter.py"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":py_solve_interrupter_lib",
|
||||
":pybind_solve_interrupter",
|
||||
requirement("absl-py"),
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "solve_interrupter_test",
|
||||
srcs = ["solve_interrupter_test.py"],
|
||||
deps = [
|
||||
":py_solve_interrupter_testing_lib",
|
||||
":pybind_solve_interrupter_testing",
|
||||
":solve_interrupter",
|
||||
requirement("absl-py"),
|
||||
],
|
||||
)
|
||||
|
||||
91
ortools/util/python/py_solve_interrupter.cc
Normal file
91
ortools/util/python/py_solve_interrupter.cc
Normal file
@@ -0,0 +1,91 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/util/python/py_solve_interrupter.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
PySolveInterrupter::PySolveInterrupter()
|
||||
: callback_(&interrupter_, [this]() { TriggerTargets(); }) {}
|
||||
|
||||
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>>
|
||||
PySolveInterrupter::CleanupAndGetTargets(
|
||||
const PySolveInterrupter* absl_nullable const to_remove) {
|
||||
// First get strong references of non-expired targets. Also filter the
|
||||
// to_remove target if found.
|
||||
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>>
|
||||
non_expired_targets;
|
||||
non_expired_targets.reserve(targets_.size());
|
||||
for (std::weak_ptr<PySolveInterrupter>& weak_target : targets_) {
|
||||
std::shared_ptr<PySolveInterrupter> strong_target = weak_target.lock();
|
||||
if (strong_target != nullptr && strong_target.get() != to_remove) {
|
||||
non_expired_targets.push_back(std::move(strong_target));
|
||||
}
|
||||
}
|
||||
|
||||
// Then recreates targets weak-references with only the non-expired targets.
|
||||
//
|
||||
// Note that we could be more efficient by doing the cleanup in-place but here
|
||||
// we keep the code simple. In practice the number of targets_ is expected to
|
||||
// be very low (less than 10).
|
||||
targets_.clear();
|
||||
for (std::shared_ptr<PySolveInterrupter>& strong_target :
|
||||
non_expired_targets) {
|
||||
targets_.push_back(strong_target);
|
||||
}
|
||||
|
||||
return non_expired_targets;
|
||||
}
|
||||
|
||||
void PySolveInterrupter::AddTriggerTarget(
|
||||
absl_nonnull std::shared_ptr<PySolveInterrupter> target) {
|
||||
const absl::MutexLock lock(&mutex_);
|
||||
CleanupAndGetTargets();
|
||||
// Note that we don't test if targets_ already contain the interrupter as
|
||||
// interrupter are triggerable only once. Thus having duplicates won't result
|
||||
// in visible changes outside.
|
||||
//
|
||||
// The way RemoveTriggerTarget() is implemented will remove duplicates as
|
||||
// well.
|
||||
targets_.push_back(std::move(target));
|
||||
}
|
||||
|
||||
void PySolveInterrupter::RemoveTriggerTarget(
|
||||
absl_nonnull std::shared_ptr<PySolveInterrupter> target) {
|
||||
const absl::MutexLock lock(&mutex_);
|
||||
CleanupAndGetTargets(/*to_remove=*/target.get());
|
||||
}
|
||||
|
||||
void PySolveInterrupter::TriggerTargets() {
|
||||
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>> targets;
|
||||
{ // Limit `lock` scope.
|
||||
const absl::MutexLock lock(&mutex_);
|
||||
targets = CleanupAndGetTargets();
|
||||
}
|
||||
|
||||
// Call targets without holding mutex_.
|
||||
for (const absl_nonnull std::shared_ptr<PySolveInterrupter>& target :
|
||||
targets) {
|
||||
target->Interrupt();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
116
ortools/util/python/py_solve_interrupter.h
Normal file
116
ortools/util/python/py_solve_interrupter.h
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_
|
||||
#define OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "ortools/util/solve_interrupter.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
// Simple class that wraps a SolveInterrupter for Clif or pybind11.
|
||||
//
|
||||
class PySolveInterrupter {
|
||||
public:
|
||||
PySolveInterrupter();
|
||||
|
||||
PySolveInterrupter(const PySolveInterrupter&) = delete;
|
||||
PySolveInterrupter& operator=(const PySolveInterrupter&) = delete;
|
||||
|
||||
// Interrupts the solve as soon as possible.
|
||||
void Interrupt() { interrupter_.Interrupt(); }
|
||||
|
||||
// Returns true if the solve interruption has been requested.
|
||||
inline bool IsInterrupted() const { return interrupter_.IsInterrupted(); }
|
||||
|
||||
// Triggers the target when this interrupter is triggered.
|
||||
//
|
||||
// A std::weak_ptr is kept on the target. Expired std::weak_ptr are cleaned up
|
||||
// on calls to AddTriggerTarget() and RemoveTriggerTarget().
|
||||
//
|
||||
// Complexity: O(num_targets).
|
||||
void AddTriggerTarget(absl_nonnull std::shared_ptr<PySolveInterrupter> target)
|
||||
ABSL_LOCKS_EXCLUDED(mutex_);
|
||||
|
||||
// Removes the target if not null and present, else do nothing.
|
||||
//
|
||||
// Complexity: O(num_targets).
|
||||
void RemoveTriggerTarget(
|
||||
absl_nonnull std::shared_ptr<PySolveInterrupter> target)
|
||||
ABSL_LOCKS_EXCLUDED(mutex_);
|
||||
|
||||
// Add a callback on the interrupter and returns an id to use to remove it.
|
||||
//
|
||||
// See SolveInterrupter::AddInterruptionCallback().
|
||||
int64_t AddInterruptionCallback(std::function<void()> callback) const {
|
||||
return interrupter_.AddInterruptionCallback(std::move(callback)).value();
|
||||
}
|
||||
|
||||
// Remove a callback previously registered by AddInterruptionCallback().
|
||||
//
|
||||
// See SolveInterrupter::RemoveInterruptionCallback().
|
||||
void RemoveInterruptionCallback(int64_t callback_id) const {
|
||||
interrupter_.RemoveInterruptionCallback(
|
||||
SolveInterrupter::CallbackId{callback_id});
|
||||
}
|
||||
|
||||
// Return the underlying interrupter. This method is not exposed in Python and
|
||||
// only available for C++ code.
|
||||
//
|
||||
// The lifetime of the PySolveInterrupter is controlled by Python. To prevent
|
||||
// issues where the underlying SolveInterrupter would be destroyed while still
|
||||
// being pointed to by C++ code, C++ Clif/pybind11 consumer code should take
|
||||
// references to PySolveInterrupter in an std::shared_ptr that outlives any
|
||||
// reference.
|
||||
const SolveInterrupter* absl_nonnull interrupter() const {
|
||||
return &interrupter_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Remove expired targets_, the optional to_remove target and return strong
|
||||
// references on non-expired targets.
|
||||
//
|
||||
// The caller must have acquired mutex_.
|
||||
//
|
||||
// Complexity: O(num_targets).
|
||||
std::vector<absl_nonnull std::shared_ptr<PySolveInterrupter>>
|
||||
CleanupAndGetTargets(const PySolveInterrupter* absl_nullable to_remove =
|
||||
nullptr) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
||||
|
||||
// Triggers all non-expired targets_ interrupters.
|
||||
void TriggerTargets() ABSL_LOCKS_EXCLUDED(mutex_);
|
||||
|
||||
SolveInterrupter interrupter_;
|
||||
|
||||
absl::Mutex mutex_;
|
||||
|
||||
// Interrupters to trigger when interrupter_ is triggered.
|
||||
std::vector<std::weak_ptr<PySolveInterrupter>> targets_
|
||||
ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
// Callback that will trigger all interrupters in target_.
|
||||
//
|
||||
// It MUST appear after targets_ and interrupter_ as we want to make sure it
|
||||
// is destroyed first!
|
||||
ScopedSolveInterrupterCallback callback_;
|
||||
};
|
||||
|
||||
} // namespace operations_research
|
||||
|
||||
#endif // OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_H_
|
||||
29
ortools/util/python/py_solve_interrupter_testing.cc
Normal file
29
ortools/util/python/py_solve_interrupter_testing.cc
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/util/python/py_solve_interrupter_testing.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "ortools/util/python/py_solve_interrupter.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
std::optional<bool> IsInterrupted(const PySolveInterrupter* interrupter) {
|
||||
if (interrupter == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return interrupter->IsInterrupted();
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
62
ortools/util/python/py_solve_interrupter_testing.h
Normal file
62
ortools/util/python/py_solve_interrupter_testing.h
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Library to unit test PySolveInterrupter wrapper.
|
||||
#ifndef OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_
|
||||
#define OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "ortools/util/python/py_solve_interrupter.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
// Returns:
|
||||
// * nullopt if `interrupter` is nullptr,
|
||||
// * or false if `interrupter` is not nullptr and is not interrupted,
|
||||
// * or true if `interrupted` is not nullptr and is interrupted.
|
||||
//
|
||||
// The Clif/pybind11 wrapper will return a `bool | None` value, with None for
|
||||
// nullopt.
|
||||
std::optional<bool> IsInterrupted(const PySolveInterrupter* interrupter);
|
||||
|
||||
// Class that keeps a reference on a std::shared_ptr<PySolveInterrupter> to
|
||||
// test that the C++ object survive the cleanup of the Python reference.
|
||||
class PySolveInterrupterReference {
|
||||
public:
|
||||
explicit PySolveInterrupterReference(
|
||||
absl_nonnull std::shared_ptr<PySolveInterrupter> interrupter)
|
||||
: interrupter_(std::move(interrupter)) {}
|
||||
|
||||
PySolveInterrupterReference(const PySolveInterrupterReference&) = delete;
|
||||
PySolveInterrupterReference& operator=(const PySolveInterrupterReference&) =
|
||||
delete;
|
||||
|
||||
// Returns the std::shared_ptr<PySolveInterrupter>::use_count().
|
||||
//
|
||||
// This is used to test that Python has stopped pointing to the object.
|
||||
int use_count() const { return interrupter_.use_count(); }
|
||||
|
||||
// Returns true if the underlying interrupter is interrupted.
|
||||
bool is_interrupted() const { return interrupter_->IsInterrupted(); }
|
||||
|
||||
private:
|
||||
const absl_nonnull std::shared_ptr<PySolveInterrupter> interrupter_;
|
||||
};
|
||||
|
||||
} // namespace operations_research
|
||||
|
||||
#endif // OR_TOOLS_UTIL_PYTHON_PY_SOLVE_INTERRUPTER_TESTING_H_
|
||||
41
ortools/util/python/pybind_solve_interrupter.cc
Normal file
41
ortools/util/python/pybind_solve_interrupter.cc
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "ortools/util/python/py_solve_interrupter.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
namespace py = ::pybind11;
|
||||
|
||||
PYBIND11_MODULE(pybind_solve_interrupter, m) {
|
||||
py::class_<PySolveInterrupter, std::shared_ptr<PySolveInterrupter>>(
|
||||
m, "PySolveInterrupter")
|
||||
.def(py::init())
|
||||
.def("interrupt", &PySolveInterrupter::Interrupt)
|
||||
.def_property_readonly("interrupted", &PySolveInterrupter::IsInterrupted)
|
||||
.def("add_trigger_target", &PySolveInterrupter::AddTriggerTarget,
|
||||
py::arg("target"))
|
||||
.def("remove_trigger_target", &PySolveInterrupter::RemoveTriggerTarget,
|
||||
py::arg("target"))
|
||||
.def("add_interruption_callback",
|
||||
&PySolveInterrupter::AddInterruptionCallback, py::arg("callback"))
|
||||
.def("remove_interruption_callback",
|
||||
&PySolveInterrupter::RemoveInterruptionCallback,
|
||||
py::arg("callback_id"));
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
139
ortools/util/python/pybind_solve_interrupter_test.py
Normal file
139
ortools/util/python/pybind_solve_interrupter_test.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2010-2025 Google LLC
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from ortools.util.python import pybind_solve_interrupter
|
||||
from ortools.util.python import pybind_solve_interrupter_testing
|
||||
|
||||
|
||||
class PybindPySolveInterrupterTest(absltest.TestCase):
|
||||
"""Test pybind_solve_interrupter.PySolveInterrupter class.
|
||||
|
||||
This tests how pybind_solve_interrupter.PySolveInterrupter is accepted by
|
||||
other
|
||||
pybind-generated APIs in Python, using pybind_solve_interrupter_testing to do
|
||||
so.
|
||||
"""
|
||||
|
||||
def test_none_interrupter(self) -> None:
|
||||
"""Test that a `const PySolveInterrupter*` can receive nullptr."""
|
||||
# Here IsInterrupted() is a Clif wrapped C++ function that expects a `const
|
||||
# PySolveInterrupter*`. It returns None (std::nullopt) for nullptr
|
||||
# input.
|
||||
self.assertIsNone(pybind_solve_interrupter_testing.IsInterrupted(None))
|
||||
|
||||
def test_untriggered_interrupter(self) -> None:
|
||||
"""Test that an untriggered interrupter is properly passed."""
|
||||
interrupter = pybind_solve_interrupter.PySolveInterrupter()
|
||||
# Test the getter.
|
||||
self.assertFalse(interrupter.interrupted)
|
||||
# Test using a Clif wrapped C++ function. We test it is not None to make
|
||||
# sure Clif passes a non-null pointer to the C++ code. We thus have to use
|
||||
# assertIs() instead of assertFalse().
|
||||
self.assertIs(
|
||||
pybind_solve_interrupter_testing.IsInterrupted(interrupter), False
|
||||
)
|
||||
|
||||
def test_triggered_interrupter(self) -> None:
|
||||
"""Test that an untriggered interrupter is properly passed."""
|
||||
interrupter = pybind_solve_interrupter.PySolveInterrupter()
|
||||
interrupter.interrupt()
|
||||
# Test the getter.
|
||||
self.assertTrue(interrupter.interrupted)
|
||||
# Test using a Clif wrapped C++ function.
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(interrupter))
|
||||
|
||||
def test_shared_reference(self) -> None:
|
||||
"""Test that taking a std::shared_ptr<PySolveInterrupter> works."""
|
||||
# Create an interrupter and a PySolveInterrupterReference class which
|
||||
# constructor takes and std::shared_ptr<PySolveInterrupter>. We expect
|
||||
# that only one instance of the C++ PySolveInterrupter will exist here.
|
||||
interrupter = pybind_solve_interrupter.PySolveInterrupter()
|
||||
interrupter_ref = pybind_solve_interrupter_testing.PySolveInterrupterReference(
|
||||
interrupter
|
||||
)
|
||||
|
||||
# Validate that we have the expected number of references of the shared_ptr
|
||||
# hold by PySolveInterrupterReference.
|
||||
self.assertEqual(interrupter_ref.use_count, 2)
|
||||
self.assertFalse(interrupter_ref.is_interrupted)
|
||||
|
||||
# Triggering the interrupter should be visible the pointed
|
||||
# PySolveInterrupter in C++.
|
||||
interrupter.interrupt()
|
||||
self.assertEqual(interrupter_ref.use_count, 2)
|
||||
self.assertTrue(interrupter_ref.is_interrupted)
|
||||
|
||||
# Removing the Python `interrupter` reference should make `interrupter_ref`
|
||||
# the only object that holds an std::shared_ptr on the interrupter.
|
||||
del interrupter
|
||||
self.assertEqual(interrupter_ref.use_count, 1)
|
||||
self.assertTrue(interrupter_ref.is_interrupted)
|
||||
|
||||
def test_add_target(self) -> None:
|
||||
source = pybind_solve_interrupter.PySolveInterrupter()
|
||||
target = pybind_solve_interrupter.PySolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(target))
|
||||
|
||||
def test_remove_existing_target(self) -> None:
|
||||
source = pybind_solve_interrupter.PySolveInterrupter()
|
||||
target = pybind_solve_interrupter.PySolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
|
||||
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
|
||||
|
||||
def test_remove_existing_target_added_twice(self) -> None:
|
||||
source = pybind_solve_interrupter.PySolveInterrupter()
|
||||
target = pybind_solve_interrupter.PySolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.add_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
|
||||
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
|
||||
|
||||
def test_dead_target(self) -> None:
|
||||
source = pybind_solve_interrupter.PySolveInterrupter()
|
||||
target = pybind_solve_interrupter.PySolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
del target
|
||||
source.interrupt()
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
|
||||
|
||||
def test_remove_existing_target_twice(self) -> None:
|
||||
source = pybind_solve_interrupter.PySolveInterrupter()
|
||||
target = pybind_solve_interrupter.PySolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
|
||||
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
|
||||
|
||||
def test_remove_non_target(self) -> None:
|
||||
source = pybind_solve_interrupter.PySolveInterrupter()
|
||||
target = pybind_solve_interrupter.PySolveInterrupter()
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(pybind_solve_interrupter_testing.IsInterrupted(source))
|
||||
self.assertFalse(pybind_solve_interrupter_testing.IsInterrupted(target))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
34
ortools/util/python/pybind_solve_interrupter_testing.cc
Normal file
34
ortools/util/python/pybind_solve_interrupter_testing.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "ortools/util/python/py_solve_interrupter_testing.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
namespace py = ::pybind11;
|
||||
|
||||
PYBIND11_MODULE(pybind_solve_interrupter_testing, m) {
|
||||
py::class_<PySolveInterrupterReference>(m, "PySolveInterrupterReference")
|
||||
.def(py::init<std::shared_ptr<PySolveInterrupter>>())
|
||||
.def_property_readonly("use_count",
|
||||
&PySolveInterrupterReference::use_count)
|
||||
.def_property_readonly("is_interrupted",
|
||||
&PySolveInterrupterReference::is_interrupted);
|
||||
m.def("IsInterrupted", &IsInterrupted, py::arg("interrupter"));
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
136
ortools/util/python/solve_interrupter.py
Normal file
136
ortools/util/python/solve_interrupter.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright 2010-2025 Google LLC
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Python interrupter for solves."""
|
||||
|
||||
from collections.abc import Callable, Iterator
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
from absl import logging
|
||||
|
||||
from ortools.util.python import pybind_solve_interrupter
|
||||
|
||||
|
||||
class CallbackError(Exception):
|
||||
"""Exception raised when an interrupter callback fails.
|
||||
|
||||
When using SolveInterrupter.interruption_callback(), this exception is raised
|
||||
when exiting the context manager if the callback failed. The error in the
|
||||
callback is the cause of this exception.
|
||||
"""
|
||||
|
||||
|
||||
class SolveInterrupter:
|
||||
"""Interrupter used by solvers to know when they should interrupt the solve.
|
||||
|
||||
Once triggered with interrupt(), an interrupter can't be reset. It can be
|
||||
triggered from any thread.
|
||||
|
||||
Thread-safety: APIs on this class are safe to call concurrently from multiple
|
||||
threads.
|
||||
|
||||
Attributes:
|
||||
pybind_interrupter: The pybind wrapper around PySolveInterrupter.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pybind_interrupter = pybind_solve_interrupter.PySolveInterrupter()
|
||||
|
||||
def interrupt(self) -> None:
|
||||
"""Interrupts the solve as soon as possible.
|
||||
|
||||
Once requested the interruption can't be reset. The user should use a new
|
||||
SolveInterrupter for later solves.
|
||||
|
||||
It is safe to call this function multiple times. Only the first call will
|
||||
have visible effects; other calls will be ignored.
|
||||
"""
|
||||
self.pybind_interrupter.interrupt()
|
||||
|
||||
@property
|
||||
def interrupted(self) -> bool:
|
||||
"""True if the solve interruption has been requested."""
|
||||
return self.pybind_interrupter.interrupted
|
||||
|
||||
def add_trigger_target(self, target: "SolveInterrupter") -> None:
|
||||
"""Triggers the target when this interrupter is triggered."""
|
||||
self.pybind_interrupter.add_trigger_target(target.pybind_interrupter)
|
||||
|
||||
def remove_trigger_target(self, target: "SolveInterrupter") -> None:
|
||||
"""Removes the target if not null and present, else do nothing."""
|
||||
self.pybind_interrupter.remove_trigger_target(target.pybind_interrupter)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def interruption_callback(self, callback: Callable[[], None]) -> Iterator[None]:
|
||||
"""Returns a context manager that (un)register the provided callback.
|
||||
|
||||
The callback is immediately called if the interrupter has already been
|
||||
triggered or if it is triggered during the registration. This is typically
|
||||
useful for a solver implementation so that it does not have to test
|
||||
`interrupted` to do the same thing it does in the callback. Simply
|
||||
registering the callback is enough.
|
||||
|
||||
The callback function can't make calls to interruption_callback(), and
|
||||
interrupt(). This would result is a deadlock. Reading `interrupted` is fine
|
||||
though.
|
||||
|
||||
Exceptions raised in the callback are raised on exit from the context
|
||||
manager if no other error happens within the context. Else the exception is
|
||||
logged.
|
||||
|
||||
Args:
|
||||
callback: The callback.
|
||||
|
||||
Returns:
|
||||
A context manager.
|
||||
|
||||
Raises:
|
||||
CallbackError: When exiting the context manager if an exception was raised
|
||||
in the callback.
|
||||
"""
|
||||
callback_error: Optional[Exception] = None
|
||||
|
||||
def protetected_callback():
|
||||
"""Calls callback() storing any exception in callback_error."""
|
||||
nonlocal callback_error
|
||||
try:
|
||||
callback()
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
# It is fine to set callback_error without any threading protection as
|
||||
# the SolveInterrupter guarantees it will only call it at most once.
|
||||
callback_error = e
|
||||
|
||||
callback_id = self.pybind_interrupter.add_interruption_callback(
|
||||
protetected_callback
|
||||
)
|
||||
|
||||
no_exception_in_context = False
|
||||
try:
|
||||
yield
|
||||
no_exception_in_context = True
|
||||
finally:
|
||||
self.pybind_interrupter.remove_interruption_callback(callback_id)
|
||||
# It is fine to access callback_error without threading protection after
|
||||
# remove_interruption_callback() has returned as the SolveInterrupter
|
||||
# guarantees any pending call is done and no future call can happen.
|
||||
if callback_error is not None:
|
||||
if no_exception_in_context:
|
||||
raise CallbackError() from callback_error
|
||||
# We don't want the error in the context to be masked by an error in the
|
||||
# callback. We log it instead.
|
||||
logging.error(
|
||||
"An exception occurred in callback but is masked by another"
|
||||
" exception: %s",
|
||||
repr(callback_error),
|
||||
)
|
||||
226
ortools/util/python/solve_interrupter_test.py
Normal file
226
ortools/util/python/solve_interrupter_test.py
Normal file
@@ -0,0 +1,226 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2010-2025 Google LLC
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from ortools.util.python import pybind_solve_interrupter_testing
|
||||
from ortools.util.python import solve_interrupter
|
||||
|
||||
|
||||
class SolveInterrupterTest(absltest.TestCase):
|
||||
|
||||
def test_untriggered_interrupter(self) -> None:
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
self.assertFalse(interrupter.interrupted)
|
||||
|
||||
def test_triggered_interrupter(self) -> None:
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
interrupter.interrupt()
|
||||
self.assertTrue(interrupter.interrupted)
|
||||
|
||||
def test_add_target(self) -> None:
|
||||
source = solve_interrupter.SolveInterrupter()
|
||||
target = solve_interrupter.SolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(source.interrupted)
|
||||
self.assertTrue(target.interrupted)
|
||||
|
||||
def test_remove_existing_target(self) -> None:
|
||||
source = solve_interrupter.SolveInterrupter()
|
||||
target = solve_interrupter.SolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(source.interrupted)
|
||||
self.assertFalse(target.interrupted)
|
||||
|
||||
def test_remove_existing_target_added_twice(self) -> None:
|
||||
source = solve_interrupter.SolveInterrupter()
|
||||
target = solve_interrupter.SolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.add_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(source.interrupted)
|
||||
self.assertFalse(target.interrupted)
|
||||
|
||||
def test_dead_target(self) -> None:
|
||||
source = solve_interrupter.SolveInterrupter()
|
||||
target = solve_interrupter.SolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
del target
|
||||
source.interrupt()
|
||||
self.assertTrue(source.interrupted)
|
||||
|
||||
def test_remove_existing_target_twice(self) -> None:
|
||||
source = solve_interrupter.SolveInterrupter()
|
||||
target = solve_interrupter.SolveInterrupter()
|
||||
source.add_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(source.interrupted)
|
||||
self.assertFalse(target.interrupted)
|
||||
|
||||
def test_remove_non_target(self) -> None:
|
||||
source = solve_interrupter.SolveInterrupter()
|
||||
target = solve_interrupter.SolveInterrupter()
|
||||
source.remove_trigger_target(target)
|
||||
source.interrupt()
|
||||
self.assertTrue(source.interrupted)
|
||||
self.assertFalse(target.interrupted)
|
||||
|
||||
def test_callback_already_interrupted(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
interrupter.interrupt()
|
||||
|
||||
with interrupter.interruption_callback(callback):
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
def test_callback_interruption(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
|
||||
with interrupter.interruption_callback(callback):
|
||||
self.assertEqual(num_calls, 0)
|
||||
interrupter.interrupt()
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
def test_callback_interruption_after_removal(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
|
||||
with interrupter.interruption_callback(callback):
|
||||
self.assertEqual(num_calls, 0)
|
||||
|
||||
interrupter.interrupt()
|
||||
self.assertEqual(num_calls, 0)
|
||||
|
||||
def test_callback_nointerruption(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
|
||||
with interrupter.interruption_callback(callback):
|
||||
self.assertEqual(num_calls, 0)
|
||||
|
||||
self.assertEqual(num_calls, 0)
|
||||
|
||||
def test_callback_with_exception_in_callback(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
raise ValueError("error-in-callback")
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
|
||||
# has_finished is set to true after the call to interrupter.interrupt() and
|
||||
# will be used to validate that interrupt() does not raise the exception,
|
||||
# only the __exit__() of interruption_callback() context does.
|
||||
has_finished = False
|
||||
with self.assertRaises(solve_interrupter.CallbackError) as cm:
|
||||
with interrupter.interruption_callback(callback):
|
||||
before_interrupt_num_calls = num_calls
|
||||
interrupter.interrupt()
|
||||
after_interrupt_num_calls = num_calls
|
||||
has_finished = True
|
||||
# Test the cause of the CallbackError, which should be the original error.
|
||||
self.assertIsInstance(cm.exception.__cause__, ValueError)
|
||||
self.assertEqual(str(cm.exception.__cause__), "error-in-callback")
|
||||
|
||||
self.assertEqual(before_interrupt_num_calls, 0)
|
||||
self.assertEqual(after_interrupt_num_calls, 1)
|
||||
self.assertTrue(has_finished)
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
def test_callback_with_exception_in_context(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "error-in-context"):
|
||||
with interrupter.interruption_callback(callback):
|
||||
raise ValueError("error-in-context")
|
||||
|
||||
interrupter.interrupt()
|
||||
self.assertEqual(num_calls, 0)
|
||||
|
||||
def test_callback_with_exception_in_callback_and_context(self) -> None:
|
||||
num_calls = 0
|
||||
|
||||
def callback():
|
||||
nonlocal num_calls
|
||||
num_calls += 1
|
||||
raise ValueError("error-in-callback")
|
||||
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "error-in-context"):
|
||||
with interrupter.interruption_callback(callback):
|
||||
before_interrupt_num_calls = num_calls
|
||||
interrupter.interrupt()
|
||||
after_interrupt_num_calls = num_calls
|
||||
raise ValueError("error-in-context")
|
||||
|
||||
self.assertEqual(before_interrupt_num_calls, 0)
|
||||
self.assertEqual(after_interrupt_num_calls, 1)
|
||||
self.assertEqual(num_calls, 1)
|
||||
|
||||
def test_pybind_interrupter(self) -> None:
|
||||
interrupter = solve_interrupter.SolveInterrupter()
|
||||
pybind_interrupter = interrupter.pybind_interrupter
|
||||
|
||||
self.assertFalse(
|
||||
pybind_solve_interrupter_testing.IsInterrupted(pybind_interrupter)
|
||||
)
|
||||
|
||||
interrupter.interrupt()
|
||||
|
||||
self.assertTrue(
|
||||
pybind_solve_interrupter_testing.IsInterrupted(pybind_interrupter)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
Reference in New Issue
Block a user