[CP-SAT] fix presolve bug; fix callback bug

This commit is contained in:
Laurent Perron
2024-09-23 15:28:18 +02:00
parent f3f8830ccb
commit f053be9786
10 changed files with 385 additions and 91 deletions

View File

@@ -3004,6 +3004,7 @@ cc_library(
":cp_model_utils",
":model",
":sat_parameters_cc_proto",
":util",
"//ortools/util:logging",
"//ortools/util:sorted_interval_list",
"//ortools/util:time_limit",

View File

@@ -172,6 +172,13 @@ class CpModelMapping {
return reverse_integer_map_[var];
}
// This one should only be used when we have a mapping.
int GetProtoLiteralFromLiteral(sat::Literal lit) const {
const int proto_var = GetProtoVariableFromBooleanVariable(lit.Variable());
DCHECK_NE(proto_var, -1);
return lit.IsPositive() ? proto_var : NegatedRef(proto_var);
}
const std::vector<IntegerVariable>& GetVariableMapping() const {
return integers_;
}

View File

@@ -2508,7 +2508,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) {
context_->UpdateRuleStats("linear1: infeasible");
return MarkConstraintAsFalse(ct);
}
if (rhs == context_->DomainOf(var)) {
if (rhs == var_domain) {
context_->UpdateRuleStats("linear1: always true");
return RemoveConstraint(ct);
}
@@ -2544,16 +2544,28 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) {
}
// Detect encoding.
bool changed = false;
if (ct->enforcement_literal().size() == 1) {
// If we already have an encoding literal, this constraint is really
// an implication.
const int lit = ct->enforcement_literal(0);
int lit = ct->enforcement_literal(0);
// For correctness below, it is important lit is the canonical literal,
// otherwise we might remove the constraint even though it is the one
// defining an encoding literal.
const int representative = context_->GetLiteralRepresentative(lit);
if (lit != representative) {
lit = representative;
ct->set_enforcement_literal(0, lit);
context_->UpdateRuleStats("linear1: remapped enforcement literal");
changed = true;
}
if (rhs.IsFixed()) {
const int64_t value = rhs.FixedValue();
int encoding_lit;
if (context_->HasVarValueEncoding(var, value, &encoding_lit)) {
if (lit == encoding_lit) return false;
if (lit == encoding_lit) return changed;
context_->AddImplication(lit, encoding_lit);
context_->UpdateNewConstraintsVariableUsage();
ct->Clear();
@@ -2567,7 +2579,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) {
}
context_->UpdateNewConstraintsVariableUsage();
}
return false;
return changed;
}
const Domain complement = rhs.Complement().IntersectionWith(var_domain);
@@ -2575,7 +2587,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) {
const int64_t value = complement.FixedValue();
int encoding_lit;
if (context_->HasVarValueEncoding(var, value, &encoding_lit)) {
if (NegatedRef(lit) == encoding_lit) return false;
if (NegatedRef(lit) == encoding_lit) return changed;
context_->AddImplication(lit, NegatedRef(encoding_lit));
context_->UpdateNewConstraintsVariableUsage();
ct->Clear();
@@ -2589,11 +2601,11 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) {
}
context_->UpdateNewConstraintsVariableUsage();
}
return false;
return changed;
}
}
return false;
return changed;
}
bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) {
@@ -7110,9 +7122,6 @@ void CpModelPresolver::Probe() {
}
probing_timer->AddCounter("fixed_bools", num_fixed);
DetectDuplicateConstraintsWithDifferentEnforcements(
mapping, implication_graph, model.GetOrCreate<Trail>());
int num_equiv = 0;
int num_changed_bounds = 0;
const int num_variables = context_->working_model->variables().size();
@@ -7148,6 +7157,12 @@ void CpModelPresolver::Probe() {
probing_timer->AddCounter("new_binary_clauses",
prober->num_new_binary_clauses());
// Note that we prefer to run this after we exported all equivalence to the
// context, so that our enforcement list can be presolved to the best of our
// knowledge.
DetectDuplicateConstraintsWithDifferentEnforcements(
mapping, implication_graph, model.GetOrCreate<Trail>());
// Stop probing timer now and display info.
probing_timer.reset();
@@ -8888,37 +8903,20 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements(
for (const auto& [dup, rep] : duplicates_without_enforcement) {
auto* dup_ct = context_->working_model->mutable_constraints(dup);
auto* rep_ct = context_->working_model->mutable_constraints(rep);
if (rep_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) {
continue;
// Make sure our enforcement list are up to date: nothing fixed and that
// its uses the literal representatives.
if (PresolveEnforcementLiteral(dup_ct)) {
context_->UpdateConstraintVariableUsage(dup);
}
if (PresolveEnforcementLiteral(rep_ct)) {
context_->UpdateConstraintVariableUsage(rep);
}
// If we have a trail, we can check if any variable of the enforcement is
// fixed to false. This is useful for what follows since calling
// implication_graph->DirectImplications() is invalid for fixed variables.
if (trail != nullptr) {
bool found_false_enforcement = false;
for (const int c : {dup, rep}) {
for (const int l :
context_->working_model->constraints(c).enforcement_literal()) {
if (trail->Assignment().LiteralIsFalse(mapping->Literal(l))) {
found_false_enforcement = true;
break;
}
}
if (found_false_enforcement) {
context_->UpdateRuleStats("enforcement: false literal");
if (c == rep) {
rep_ct->Swap(dup_ct);
context_->UpdateConstraintVariableUsage(rep);
}
dup_ct->Clear();
context_->UpdateConstraintVariableUsage(dup);
break;
}
}
if (found_false_enforcement) {
continue;
}
// Skip this pair if one of the constraint was simplified
if (rep_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET ||
dup_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) {
continue;
}
// If one of them has no enforcement, then the other can be ignored.
@@ -8936,10 +8934,7 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements(
// Special case. This looks specific but users might reify with a cost
// a duplicate constraint. In this case, no need to have two variables,
// we can make them equal by duality argument.
const int a = rep_ct->enforcement_literal(0);
const int b = dup_ct->enforcement_literal(0);
if (context_->IsFixed(a) || context_->IsFixed(b)) continue;
//
// TODO(user): Deal with more general situation? Note that we already
// do something similar in dual_bound_strengthening.Strengthen() were we
// are more general as we just require an unique blocking constraint rather
@@ -8949,6 +8944,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements(
// we can also add the equality. Alternatively, we can just introduce a new
// variable and merge all duplicate constraint into 1 + bunch of boolean
// constraints liking enforcements.
const int a = rep_ct->enforcement_literal(0);
const int b = dup_ct->enforcement_literal(0);
if (context_->VariableWithCostIsUniqueAndRemovable(a) &&
context_->VariableWithCostIsUniqueAndRemovable(b)) {
// Both these case should be presolved before, but it is easy to deal with
@@ -9007,19 +9004,19 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements(
// B, then constraint A is redundant and we can remove it.
const int c_a = i == 0 ? dup : rep;
const int c_b = i == 0 ? rep : dup;
const auto& ct_a = context_->working_model->constraints(c_a);
const auto& ct_b = context_->working_model->constraints(c_b);
enforcement_vars.clear();
implications_used.clear();
for (const int proto_lit :
context_->working_model->constraints(c_b).enforcement_literal()) {
for (const int proto_lit : ct_b.enforcement_literal()) {
const Literal lit = mapping->Literal(proto_lit);
if (trail->Assignment().LiteralIsTrue(lit)) continue;
DCHECK(!trail->Assignment().LiteralIsAssigned(lit));
enforcement_vars.insert(lit);
}
for (const int proto_lit :
context_->working_model->constraints(c_a).enforcement_literal()) {
for (const int proto_lit : ct_a.enforcement_literal()) {
const Literal lit = mapping->Literal(proto_lit);
if (trail->Assignment().LiteralIsTrue(lit)) continue;
DCHECK(!trail->Assignment().LiteralIsAssigned(lit));
for (const Literal implication_lit :
implication_graph->DirectImplications(lit)) {
auto extracted = enforcement_vars.extract(implication_lit);
@@ -9029,6 +9026,71 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements(
}
}
if (enforcement_vars.empty()) {
// Tricky: Because we keep track of literal <=> var == value, we
// cannot easily simplify linear1 here. This is because a scenario
// like this can happen:
//
// We have registered the fact that a <=> X=1 because we saw two
// constraints a => X=1 and not(a) => X!= 1
//
// Now, we are here and we have:
// a => X=1, b => X=1, a => b
// So we rewrite this as
// a => b, b => X=1
//
// But later, the PresolveLinearOfSizeOne() see
// b => X=1 and just rewrite this as b => a since (a <=> X=1).
// This is wrong because the constraint "b => X=1" is needed for the
// equivalence (a <=> X=1), but we lost that fact.
//
// Note(user): In the scenario above we can see that a <=> b, and if
// we know that fact, then the transformation is correctly handled.
// The bug was triggered when the Probing finished early due to time
// limit and we never detected that equivalence.
//
// TODO(user): Try to find a cleaner way to handle this. We could
// query our HasVarValueEncoding() directly here and directly detect a
// <=> b. However we also need to figure the case of
// half-implications.
{
if (ct_a.constraint_case() == ConstraintProto::kLinear &&
ct_a.linear().vars().size() == 1 &&
ct_a.enforcement_literal().size() == 1) {
const int var = ct_a.linear().vars(0);
const Domain var_domain = context_->DomainOf(var);
const Domain rhs =
ReadDomainFromProto(ct_a.linear())
.InverseMultiplicationBy(ct_a.linear().coeffs(0))
.IntersectionWith(var_domain);
// IsFixed() do not work on empty domain.
if (rhs.IsEmpty()) {
context_->UpdateRuleStats("duplicate: linear1 infeasible");
if (!MarkConstraintAsFalse(rep_ct)) return;
if (!MarkConstraintAsFalse(dup_ct)) return;
context_->UpdateConstraintVariableUsage(rep);
context_->UpdateConstraintVariableUsage(dup);
continue;
}
if (rhs == var_domain) {
context_->UpdateRuleStats("duplicate: linear1 always true");
rep_ct->Clear();
dup_ct->Clear();
context_->UpdateConstraintVariableUsage(rep);
context_->UpdateConstraintVariableUsage(dup);
continue;
}
// We skip if it is a var == value or var != value constraint.
if (rhs.IsFixed() ||
rhs.Complement().IntersectionWith(var_domain).IsFixed()) {
context_->UpdateRuleStats(
"TODO duplicate: skipped identical encoding constraints");
continue;
}
}
}
context_->UpdateRuleStats(
"duplicate: identical constraint with implied enforcements");
if (c_a == rep) {
@@ -9043,12 +9105,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements(
// graph. This is because in some case the implications are only true
// in the presence of the "duplicated" constraints.
for (const auto& [a, b] : implications_used) {
const int var_a =
mapping->GetProtoVariableFromBooleanVariable(a.Variable());
const int proto_lit_a = a.IsPositive() ? var_a : NegatedRef(var_a);
const int var_b =
mapping->GetProtoVariableFromBooleanVariable(b.Variable());
const int proto_lit_b = b.IsPositive() ? var_b : NegatedRef(var_b);
const int proto_lit_a = mapping->GetProtoLiteralFromLiteral(a);
const int proto_lit_b = mapping->GetProtoLiteralFromLiteral(b);
context_->AddImplication(proto_lit_a, proto_lit_b);
}
context_->UpdateNewConstraintsVariableUsage();

View File

@@ -13,6 +13,7 @@
#include "ortools/sat/feasibility_jump.h"
#include <cstdint>
#include <utility>
#include "gtest/gtest.h"

View File

@@ -40,6 +40,7 @@
#include "ortools/sat/sat_parameters.pb.h"
#include "ortools/sat/sat_solver.h"
#include "ortools/sat/synchronization.h"
#include "ortools/sat/util.h"
#include "ortools/util/saturated_arithmetic.h"
#include "ortools/util/sorted_interval_list.h"
#include "ortools/util/strong_integers.h"
@@ -610,11 +611,11 @@ bool FeasibilityPump::PropagationRounding() {
}
const int64_t rounded_value =
static_cast<int64_t>(std::round(lp_solution_[var_index]));
SafeDoubleToInt64(std::round(lp_solution_[var_index]));
const int64_t floor_value =
static_cast<int64_t>(std::floor(lp_solution_[var_index]));
SafeDoubleToInt64(std::floor(lp_solution_[var_index]));
const int64_t ceil_value =
static_cast<int64_t>(std::ceil(lp_solution_[var_index]));
SafeDoubleToInt64(std::ceil(lp_solution_[var_index]));
const bool floor_is_in_domain =
(domain.Contains(floor_value) && lb.value() <= floor_value);

View File

@@ -1371,8 +1371,9 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) {
max_literal = max_it->second.Get(this);
if (min_literal != NegatedRef(max_literal)) {
UpdateRuleStats("variables with 2 values: merge encoding literals");
StoreBooleanEqualityRelation(min_literal, NegatedRef(max_literal));
if (is_unsat_) return;
if (!StoreBooleanEqualityRelation(min_literal, NegatedRef(max_literal))) {
return;
}
}
min_literal = GetLiteralRepresentative(min_literal);
max_literal = GetLiteralRepresentative(max_literal);
@@ -1419,7 +1420,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) {
}
}
void PresolveContext::InsertVarValueEncodingInternal(int literal, int var,
bool PresolveContext::InsertVarValueEncodingInternal(int literal, int var,
int64_t value,
bool add_constraints) {
DCHECK(RefIsPositive(var));
@@ -1446,10 +1447,12 @@ void PresolveContext::InsertVarValueEncodingInternal(int literal, int var,
if (literal != previous_literal) {
UpdateRuleStats(
"variables: merge equivalent var value encoding literals");
StoreBooleanEqualityRelation(literal, previous_literal);
if (!StoreBooleanEqualityRelation(literal, previous_literal)) {
return false;
}
}
}
return;
return true;
}
if (DomainOf(var).Size() == 2) {
@@ -1461,6 +1464,9 @@ void PresolveContext::InsertVarValueEncodingInternal(int literal, int var,
AddImplyInDomain(literal, var, Domain(value));
AddImplyInDomain(NegatedRef(literal), var, Domain(value).Complement());
}
// The canonicalization might have proven UNSAT.
return !ModelIsUnsat();
}
bool PresolveContext::InsertHalfVarValueEncoding(int literal, int var,
@@ -1484,8 +1490,10 @@ bool PresolveContext::InsertHalfVarValueEncoding(int literal, int var,
if (other_set.contains({NegatedRef(literal), var, value})) {
UpdateRuleStats("variables: detect fully reified value encoding");
const int imply_eq_literal = imply_eq ? literal : NegatedRef(literal);
InsertVarValueEncodingInternal(imply_eq_literal, var, value,
/*add_constraints=*/false);
if (!InsertVarValueEncodingInternal(imply_eq_literal, var, value,
/*add_constraints=*/false)) {
return false;
}
}
return true;
@@ -1505,7 +1513,10 @@ bool PresolveContext::InsertVarValueEncoding(int literal, int var,
return SetLiteralToFalse(literal);
}
literal = GetLiteralRepresentative(literal);
InsertVarValueEncodingInternal(literal, var, value, /*add_constraints=*/true);
if (!InsertVarValueEncodingInternal(literal, var, value,
/*add_constraints=*/true)) {
return false;
}
eq_half_encoding_.insert({literal, var, value});
neq_half_encoding_.insert({NegatedRef(literal), var, value});

View File

@@ -664,7 +664,8 @@ class PresolveContext {
bool imply_eq);
// Insert fully reified var-value encoding.
void InsertVarValueEncodingInternal(int literal, int var, int64_t value,
// Returns false if this make the problem infeasible.
bool InsertVarValueEncodingInternal(int literal, int var, int64_t value,
bool add_constraints);
SolverLogger* logger_;

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ortools.sat.python.cp_model."""
import itertools
from absl.testing import absltest
import pandas as pd
@@ -95,6 +95,20 @@ class RecordSolution(cp_model.CpSolverSolutionCallback):
return self.__bool_var_values
class TimeRecorder(cp_model.CpSolverSolutionCallback):
def __init__(self, default_time: float) -> None:
super().__init__()
self.__last_time = default_time
def on_solution_callback(self) -> None:
self.__last_time = self.wall_time
@property
def last_time(self):
return self.__last_time
class LogToString:
"""Record log in a string."""
@@ -1649,6 +1663,215 @@ class CpModelTest(absltest.TestCase):
)
self.assertLen(model.proto.constraints, 13)
def testIssue4376SatModel(self):
print("testIssue4376SatModel")
letters: str = "BCFLMRT"
def symbols_from_string(text: str) -> list[int]:
return [letters.index(char) for char in text]
def rotate_symbols(symbols: list[int], turns: int) -> list[int]:
return symbols[turns:] + symbols[:turns]
data = """FMRC
FTLB
MCBR
FRTM
FBTM
BRFM
BTRM
BCRM
RTCF
TFRC
CTRM
CBTM
TFBM
TCBM
CFTM
BLTR
RLFM
CFLM
CRML
FCLR
FBTR
TBRF
RBCF
RBCT
BCTF
TFCR
CBRT
FCBT
FRTB
RBCM
MTFC
MFTC
MBFC
RTBM
RBFM
TRFM"""
tiles = [symbols_from_string(line) for line in data.splitlines()]
model = cp_model.CpModel()
# choices[i, x, y, r] is true iff we put tile i in cell (x,y) with
# rotation r.
choices = {}
for i in range(len(tiles)):
for x in range(6):
for y in range(6):
for r in range(4):
choices[(i, x, y, r)] = model.new_bool_var(
f"tile_{i}_{x}_{y}_{r}"
)
# corners[x, y, s] is true iff the corner at (x,y) contains symbol s.
corners = {}
for x in range(7):
for y in range(7):
for s in range(7):
corners[(x, y, s)] = model.new_bool_var(f"corner_{x}_{y}_{s}")
# Placing a tile puts a symbol in each corner.
for (i, x, y, r), choice in choices.items():
symbols = rotate_symbols(tiles[i], r)
model.add_implication(choice, corners[x, y, symbols[0]])
model.add_implication(choice, corners[x, y + 1, symbols[1]])
model.add_implication(choice, corners[x + 1, y + 1, symbols[2]])
model.add_implication(choice, corners[x + 1, y, symbols[3]])
# We must make exactly one choice for each tile.
for i in range(len(tiles)):
tmp_literals = []
for x in range(6):
for y in range(6):
for r in range(4):
tmp_literals.append(choices[(i, x, y, r)])
model.add_exactly_one(tmp_literals)
# We must make exactly one choice for each square.
for x, y in itertools.product(range(6), range(6)):
tmp_literals = []
for i in range(len(tiles)):
for r in range(4):
tmp_literals.append(choices[(i, x, y, r)])
model.add_exactly_one(tmp_literals)
# Each corner contains exactly one symbol.
for x, y in itertools.product(range(7), range(7)):
model.add_exactly_one(corners[x, y, s] for s in range(7))
# Solve.
solver = cp_model.CpSolver()
solver.parameters.num_workers = 8
solver.parameters.max_time_in_seconds = 20
solver.parameters.log_search_progress = True
solver.parameters.cp_model_presolve = False
solver.parameters.symmetry_level = 0
callback = TimeRecorder(solver.parameters.max_time_in_seconds)
solver.Solve(model, callback)
self.assertLess(solver.wall_time, callback.last_time + 5.0)
def testIssue4376MinimizeModel(self):
print("testIssue4376MinimizeModel")
model = cp_model.CpModel()
jobs = [
[3, 3], # [duration, width]
[2, 5],
[1, 3],
[3, 7],
[7, 3],
[2, 2],
[2, 2],
[5, 5],
[10, 2],
[4, 3],
[2, 6],
[1, 2],
[6, 8],
[4, 5],
[3, 7],
]
max_width = 10
horizon = sum(t[0] for t in jobs)
num_jobs = len(jobs)
all_jobs = range(num_jobs)
intervals = []
intervals0 = []
intervals1 = []
performed = []
starts = []
ends = []
demands = []
for i in all_jobs:
# Create main interval.
start = model.new_int_var(0, horizon, f"start_{i}")
duration = jobs[i][0]
end = model.new_int_var(0, horizon, f"end_{i}")
interval = model.new_interval_var(start, duration, end, f"interval_{i}")
starts.append(start)
intervals.append(interval)
ends.append(end)
demands.append(jobs[i][1])
# Create an optional copy of interval to be executed on machine 0.
performed_on_m0 = model.new_bool_var(f"perform_{i}_on_m0")
performed.append(performed_on_m0)
start0 = model.new_int_var(0, horizon, f"start_{i}_on_m0")
end0 = model.new_int_var(0, horizon, f"end_{i}_on_m0")
interval0 = model.new_optional_interval_var(
start0, duration, end0, performed_on_m0, f"interval_{i}_on_m0"
)
intervals0.append(interval0)
# Create an optional copy of interval to be executed on machine 1.
start1 = model.new_int_var(0, horizon, f"start_{i}_on_m1")
end1 = model.new_int_var(0, horizon, f"end_{i}_on_m1")
interval1 = model.new_optional_interval_var(
start1,
duration,
end1,
~performed_on_m0,
f"interval_{i}_on_m1",
)
intervals1.append(interval1)
# We only propagate the constraint if the tasks is performed on the
# machine.
model.add(start0 == start).only_enforce_if(performed_on_m0)
model.add(start1 == start).only_enforce_if(~performed_on_m0)
# Width constraint (modeled as a cumulative)
model.add_cumulative(intervals, demands, max_width)
# Choose which machine to perform the jobs on.
model.add_no_overlap(intervals0)
model.add_no_overlap(intervals1)
# Objective variable.
makespan = model.new_int_var(0, horizon, "makespan")
model.add_max_equality(makespan, ends)
model.minimize(makespan)
# Symmetry breaking.
model.add(performed[0] == 0)
# Solve.
solver = cp_model.CpSolver()
solver.parameters.num_workers = 8
solver.parameters.max_time_in_seconds = 50
solver.parameters.log_search_progress = True
callback = TimeRecorder(solver.parameters.max_time_in_seconds)
solver.Solve(model, callback)
self.assertLess(solver.wall_time, callback.last_time + 5.0)
if __name__ == "__main__":
absltest.main()

View File

@@ -15,7 +15,6 @@
#include <stdint.h>
#include <atomic>
#include <functional>
#include <string>
@@ -27,9 +26,9 @@
#include "ortools/sat/cp_model_utils.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_parameters.pb.h"
#include "ortools/sat/util.h"
#include "ortools/util/logging.h"
#include "ortools/util/sorted_interval_list.h"
#include "ortools/util/time_limit.h"
namespace operations_research {
namespace sat {
@@ -90,18 +89,15 @@ bool SolutionCallback::SolutionBooleanValue(int index) {
}
void SolutionCallback::StopSearch() {
if (stopped_ptr_ != nullptr) {
(*stopped_ptr_) = true;
}
if (wrapper_ != nullptr) wrapper_->StopSearch();
}
operations_research::sat::CpSolverResponse SolutionCallback::Response() const {
return response_;
}
void SolutionCallback::SetAtomicBooleanToStopTheSearch(
std::atomic<bool>* stopped_ptr) const {
stopped_ptr_ = stopped_ptr;
void SolutionCallback::SetWrapperClass(SolveWrapper* wrapper) const {
wrapper_ = wrapper;
}
bool SolutionCallback::HasResponse() const { return has_response_; }
@@ -116,15 +112,13 @@ void SolveWrapper::SetStringParameters(const std::string& string_parameters) {
}
void SolveWrapper::AddSolutionCallback(const SolutionCallback& callback) {
// Overwrite the atomic bool.
callback.SetAtomicBooleanToStopTheSearch(&stopped_);
callback.SetWrapperClass(this);
model_.Add(NewFeasibleSolutionObserver(
[&callback](const CpSolverResponse& r) { return callback.Run(r); }));
}
void SolveWrapper::ClearSolutionCallback(const SolutionCallback& callback) {
// cleanup the atomic bool.
callback.SetAtomicBooleanToStopTheSearch(nullptr);
callback.SetWrapperClass(nullptr); // Detach the wrapper class.
}
void SolveWrapper::AddLogCallback(
@@ -157,11 +151,13 @@ void SolveWrapper::AddBestBoundCallbackFromClass(BestBoundCallback* callback) {
operations_research::sat::CpSolverResponse SolveWrapper::Solve(
const operations_research::sat::CpModelProto& model_proto) {
FixFlagsAndEnvironmentForSwig();
model_.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(&stopped_);
return operations_research::sat::SolveCpModel(model_proto, &model_);
}
void SolveWrapper::StopSearch() { stopped_ = true; }
void SolveWrapper::StopSearch() {
model_.GetOrCreate<ModelSharedTimeLimit>()->Stop();
}
std::string CpSatHelper::ModelStats(
const operations_research::sat::CpModelProto& model_proto) {
return CpModelStats(model_proto);

View File

@@ -14,24 +14,20 @@
#ifndef OR_TOOLS_SAT_SWIG_HELPER_H_
#define OR_TOOLS_SAT_SWIG_HELPER_H_
#include <atomic>
#include <cstdint>
#include <functional>
#include <string>
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/cp_model_checker.h"
#include "ortools/sat/cp_model_solver.h"
#include "ortools/sat/cp_model_utils.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_parameters.pb.h"
#include "ortools/util/logging.h"
#include "ortools/util/sorted_interval_list.h"
#include "ortools/util/time_limit.h"
namespace operations_research {
namespace sat {
class SolveWrapper;
// Base class for SWIG director based on solution callbacks.
// See http://www.swig.org/Doc4.0/SWIGDocumentation.html#CSharp_directors.
class SolutionCallback {
@@ -72,14 +68,14 @@ class SolutionCallback {
operations_research::sat::CpSolverResponse Response() const;
// We use mutable and non const methods to overcome SWIG difficulties.
void SetAtomicBooleanToStopTheSearch(std::atomic<bool>* stopped_ptr) const;
void SetWrapperClass(SolveWrapper* wrapper) const;
bool HasResponse() const;
private:
mutable CpSolverResponse response_;
mutable bool has_response_ = false;
mutable std::atomic<bool>* stopped_ptr_;
mutable SolveWrapper* wrapper_ = nullptr;
};
// Simple director class for C#.
@@ -126,7 +122,6 @@ class SolveWrapper {
private:
Model model_;
std::atomic<bool> stopped_ = false;
};
// Static methods are stored in a module which name can vary.