regroup all reversible integer value in a single repository

This commit is contained in:
Laurent Perron
2017-07-06 04:56:28 -07:00
parent 49c7e20c37
commit 7b357a2263
4 changed files with 46 additions and 32 deletions

View File

@@ -331,17 +331,20 @@ LiteralIndex IntegerEncoder::SearchForLiteralAtOrBefore(
}
bool IntegerTrail::Propagate(Trail* trail) {
const int level = trail->CurrentDecisionLevel();
for (ReversibleInterface* rev : reversible_classes_) rev->SetLevel(level);
// Make sure that our internal "integer_decision_levels_" size matches the
// sat decision levels. At the level zero, integer_decision_levels_ should
// be empty.
if (trail->CurrentDecisionLevel() > integer_decision_levels_.size()) {
if (level > integer_decision_levels_.size()) {
integer_decision_levels_.push_back(integer_trail_.size());
CHECK_EQ(trail->CurrentDecisionLevel(), integer_decision_levels_.size());
}
// This is used to map any integer literal out of the initial variable domain
// into one that use one of the domain value.
var_to_current_lb_interval_index_.SetLevel(trail->CurrentDecisionLevel());
var_to_current_lb_interval_index_.SetLevel(level);
// Process all the "associated" literals and Enqueue() the corresponding
// bounds.
@@ -359,16 +362,17 @@ bool IntegerTrail::Propagate(Trail* trail) {
}
void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) {
var_to_current_lb_interval_index_.SetLevel(trail.CurrentDecisionLevel());
const int level = trail.CurrentDecisionLevel();
for (ReversibleInterface* rev : reversible_classes_) rev->SetLevel(level);
var_to_current_lb_interval_index_.SetLevel(level);
propagation_trail_index_ =
std::min(propagation_trail_index_, literal_trail_index);
// Note that if a conflict was detected before Propagate() of this class was
// even called, it is possible that there is nothing to backtrack.
const int decision_level = trail.CurrentDecisionLevel();
if (decision_level >= integer_decision_levels_.size()) return;
const int target = integer_decision_levels_[decision_level];
integer_decision_levels_.resize(decision_level);
if (level >= integer_decision_levels_.size()) return;
const int target = integer_decision_levels_[level];
integer_decision_levels_.resize(level);
CHECK_GE(target, vars_.size());
// This is needed for the code below to work.
@@ -1034,14 +1038,8 @@ void IntegerTrail::EnqueueLiteral(
GenericLiteralWatcher::GenericLiteralWatcher(Model* model)
: SatPropagator("GenericLiteralWatcher"),
integer_trail_(model->GetOrCreate<IntegerTrail>()) {
// TODO(user): Have a general mecanism to register "global" reversible
// classes and keep them synchronized with the search.
std::unique_ptr<RevRepository<int>> rev_int_repository(
new RevRepository<int>());
rev_int_repository_ = rev_int_repository.get();
model->SetSingleton(std::move(rev_int_repository));
integer_trail_(model->GetOrCreate<IntegerTrail>()),
rev_int_repository_(model->GetOrCreate<RevIntRepository>()) {
// TODO(user): This propagator currently needs to be last because it is the
// only one enforcing that a fix-point is reached on the integer variables.
// Figure out a better interaction between the sat propagation loop and
@@ -1086,7 +1084,6 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) {
bool GenericLiteralWatcher::Propagate(Trail* trail) {
const int level = trail->CurrentDecisionLevel();
rev_int_repository_->SetLevel(level);
UpdateCallingNeeds(trail);
// Note that the priority may be set to -1 inside the loop in order to restart
@@ -1201,7 +1198,6 @@ void GenericLiteralWatcher::Untrail(const Trail& trail, int trail_index) {
in_queue_.assign(watchers_.size(), false);
const int level = trail.CurrentDecisionLevel();
rev_int_repository_->SetLevel(level);
for (int& ref : id_to_greatest_common_level_since_last_call_) {
ref = std::min(ref, level);
}

View File

@@ -549,6 +549,12 @@ class IntegerTrail : public SatPropagator {
return vars_[var.value()].current_trail_index < vars_.size();
}
// Registers a reversible class. This class will always be synced with the
// correct decision level.
void RegisterReversibleClass(ReversibleInterface* rev) {
reversible_classes_.push_back(rev);
}
private:
// Tests that all the literals in the given reason are assigned to false.
// This is used to DCHECK the given reasons to the Enqueue*() functions.
@@ -659,6 +665,7 @@ class IntegerTrail : public SatPropagator {
int64 num_enqueues_;
std::vector<SparseBitset<IntegerVariable>*> watchers_;
std::vector<ReversibleInterface*> reversible_classes_;
IntegerDomains* domains_;
IntegerEncoder* encoder_;
@@ -702,6 +709,21 @@ class PropagatorInterface {
}
};
// Singleton for basic reversible types. We need the wrapper so that they can be
// accessed with model->GetOrCreate<>() and properly registered at creation.
class RevIntRepository : public RevRepository<int> {
public:
explicit RevIntRepository(Model* model) {
model->GetOrCreate<IntegerTrail>()->RegisterReversibleClass(this);
}
};
class RevIntegerValueRepository : public RevRepository<IntegerValue> {
public:
explicit RevIntegerValueRepository(Model* model) {
model->GetOrCreate<IntegerTrail>()->RegisterReversibleClass(this);
}
};
// This class allows registering Propagator that will be called if a
// watched Literal or LbVar changes.
//
@@ -779,7 +801,7 @@ class GenericLiteralWatcher : public SatPropagator {
void UpdateCallingNeeds(Trail* trail);
IntegerTrail* integer_trail_;
RevRepository<int>* rev_int_repository_;
RevIntRepository* rev_int_repository_;
struct WatchData {
int id;

View File

@@ -23,12 +23,13 @@ namespace sat {
IntegerSumLE::IntegerSumLE(LiteralIndex reified_literal,
const std::vector<IntegerVariable>& vars,
const std::vector<IntegerValue>& coeffs,
IntegerValue upper, Trail* trail,
IntegerTrail* integer_trail)
IntegerValue upper, Model* model)
: reified_literal_(reified_literal),
upper_bound_(upper),
trail_(trail),
integer_trail_(integer_trail),
trail_(model->GetOrCreate<Trail>()),
integer_trail_(model->GetOrCreate<IntegerTrail>()),
rev_integer_value_repository_(
model->GetOrCreate<RevIntegerValueRepository>()),
vars_(vars),
coeffs_(coeffs) {
// TODO(user): deal with this corner case.
@@ -76,7 +77,7 @@ bool IntegerSumLE::Propagate() {
}
// Save the current number of fixed variables.
rev_repository_integer_value_.SaveState(&rev_lb_fixed_vars_);
rev_integer_value_repository_->SaveState(&rev_lb_fixed_vars_);
// Compute the new lower bound and update the reversible structures.
IntegerValue lb_unfixed_vars = IntegerValue(0);
@@ -192,7 +193,6 @@ void IntegerSumLE::RegisterWith(GenericLiteralWatcher* watcher) {
watcher->WatchLiteral(Literal(reified_literal_), id);
}
watcher->RegisterReversibleInt(id, &rev_num_fixed_vars_);
watcher->RegisterReversibleClass(id, &rev_repository_integer_value_);
}
MinPropagator::MinPropagator(const std::vector<IntegerVariable>& vars,

View File

@@ -44,8 +44,7 @@ class IntegerSumLE : public PropagatorInterface {
IntegerSumLE(LiteralIndex reified_literal,
const std::vector<IntegerVariable>& vars,
const std::vector<IntegerValue>& coefficients,
IntegerValue upper_bound, Trail* trail,
IntegerTrail* integer_trail);
IntegerValue upper_bound, Model* model);
// We propagate:
// - If the sum of the individual lower-bound is > upper_bound, we fail.
@@ -66,8 +65,7 @@ class IntegerSumLE : public PropagatorInterface {
Trail* trail_;
IntegerTrail* integer_trail_;
RevRepository<IntegerValue> rev_repository_integer_value_;
RevIntegerValueRepository* rev_integer_value_repository_;
// Reversible sum of the lower bound of the fixed variables.
IntegerValue rev_lb_fixed_vars_;
@@ -213,8 +211,7 @@ inline std::function<void(Model*)> WeightedSumLowerOrEqual(
IntegerSumLE* constraint = new IntegerSumLE(
kNoLiteralIndex, vars,
std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
IntegerValue(upper_bound), model->GetOrCreate<Trail>(),
model->GetOrCreate<IntegerTrail>());
IntegerValue(upper_bound), model);
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
};
@@ -281,8 +278,7 @@ inline std::function<void(Model*)> ConditionalWeightedSumLowerOrEqual(
IntegerSumLE* constraint = new IntegerSumLE(
is_le.Index(), vars,
std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
IntegerValue(upper_bound), model->GetOrCreate<Trail>(),
model->GetOrCreate<IntegerTrail>());
IntegerValue(upper_bound), model);
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
};