regroup all reversible integer value in a single repository
This commit is contained in:
@@ -331,17 +331,20 @@ LiteralIndex IntegerEncoder::SearchForLiteralAtOrBefore(
|
||||
}
|
||||
|
||||
bool IntegerTrail::Propagate(Trail* trail) {
|
||||
const int level = trail->CurrentDecisionLevel();
|
||||
for (ReversibleInterface* rev : reversible_classes_) rev->SetLevel(level);
|
||||
|
||||
// Make sure that our internal "integer_decision_levels_" size matches the
|
||||
// sat decision levels. At the level zero, integer_decision_levels_ should
|
||||
// be empty.
|
||||
if (trail->CurrentDecisionLevel() > integer_decision_levels_.size()) {
|
||||
if (level > integer_decision_levels_.size()) {
|
||||
integer_decision_levels_.push_back(integer_trail_.size());
|
||||
CHECK_EQ(trail->CurrentDecisionLevel(), integer_decision_levels_.size());
|
||||
}
|
||||
|
||||
// This is used to map any integer literal out of the initial variable domain
|
||||
// into one that use one of the domain value.
|
||||
var_to_current_lb_interval_index_.SetLevel(trail->CurrentDecisionLevel());
|
||||
var_to_current_lb_interval_index_.SetLevel(level);
|
||||
|
||||
// Process all the "associated" literals and Enqueue() the corresponding
|
||||
// bounds.
|
||||
@@ -359,16 +362,17 @@ bool IntegerTrail::Propagate(Trail* trail) {
|
||||
}
|
||||
|
||||
void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) {
|
||||
var_to_current_lb_interval_index_.SetLevel(trail.CurrentDecisionLevel());
|
||||
const int level = trail.CurrentDecisionLevel();
|
||||
for (ReversibleInterface* rev : reversible_classes_) rev->SetLevel(level);
|
||||
var_to_current_lb_interval_index_.SetLevel(level);
|
||||
propagation_trail_index_ =
|
||||
std::min(propagation_trail_index_, literal_trail_index);
|
||||
|
||||
// Note that if a conflict was detected before Propagate() of this class was
|
||||
// even called, it is possible that there is nothing to backtrack.
|
||||
const int decision_level = trail.CurrentDecisionLevel();
|
||||
if (decision_level >= integer_decision_levels_.size()) return;
|
||||
const int target = integer_decision_levels_[decision_level];
|
||||
integer_decision_levels_.resize(decision_level);
|
||||
if (level >= integer_decision_levels_.size()) return;
|
||||
const int target = integer_decision_levels_[level];
|
||||
integer_decision_levels_.resize(level);
|
||||
CHECK_GE(target, vars_.size());
|
||||
|
||||
// This is needed for the code below to work.
|
||||
@@ -1034,14 +1038,8 @@ void IntegerTrail::EnqueueLiteral(
|
||||
|
||||
GenericLiteralWatcher::GenericLiteralWatcher(Model* model)
|
||||
: SatPropagator("GenericLiteralWatcher"),
|
||||
integer_trail_(model->GetOrCreate<IntegerTrail>()) {
|
||||
// TODO(user): Have a general mecanism to register "global" reversible
|
||||
// classes and keep them synchronized with the search.
|
||||
std::unique_ptr<RevRepository<int>> rev_int_repository(
|
||||
new RevRepository<int>());
|
||||
rev_int_repository_ = rev_int_repository.get();
|
||||
model->SetSingleton(std::move(rev_int_repository));
|
||||
|
||||
integer_trail_(model->GetOrCreate<IntegerTrail>()),
|
||||
rev_int_repository_(model->GetOrCreate<RevIntRepository>()) {
|
||||
// TODO(user): This propagator currently needs to be last because it is the
|
||||
// only one enforcing that a fix-point is reached on the integer variables.
|
||||
// Figure out a better interaction between the sat propagation loop and
|
||||
@@ -1086,7 +1084,6 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) {
|
||||
|
||||
bool GenericLiteralWatcher::Propagate(Trail* trail) {
|
||||
const int level = trail->CurrentDecisionLevel();
|
||||
rev_int_repository_->SetLevel(level);
|
||||
UpdateCallingNeeds(trail);
|
||||
|
||||
// Note that the priority may be set to -1 inside the loop in order to restart
|
||||
@@ -1201,7 +1198,6 @@ void GenericLiteralWatcher::Untrail(const Trail& trail, int trail_index) {
|
||||
in_queue_.assign(watchers_.size(), false);
|
||||
|
||||
const int level = trail.CurrentDecisionLevel();
|
||||
rev_int_repository_->SetLevel(level);
|
||||
for (int& ref : id_to_greatest_common_level_since_last_call_) {
|
||||
ref = std::min(ref, level);
|
||||
}
|
||||
|
||||
@@ -549,6 +549,12 @@ class IntegerTrail : public SatPropagator {
|
||||
return vars_[var.value()].current_trail_index < vars_.size();
|
||||
}
|
||||
|
||||
// Registers a reversible class. This class will always be synced with the
|
||||
// correct decision level.
|
||||
void RegisterReversibleClass(ReversibleInterface* rev) {
|
||||
reversible_classes_.push_back(rev);
|
||||
}
|
||||
|
||||
private:
|
||||
// Tests that all the literals in the given reason are assigned to false.
|
||||
// This is used to DCHECK the given reasons to the Enqueue*() functions.
|
||||
@@ -659,6 +665,7 @@ class IntegerTrail : public SatPropagator {
|
||||
int64 num_enqueues_;
|
||||
|
||||
std::vector<SparseBitset<IntegerVariable>*> watchers_;
|
||||
std::vector<ReversibleInterface*> reversible_classes_;
|
||||
|
||||
IntegerDomains* domains_;
|
||||
IntegerEncoder* encoder_;
|
||||
@@ -702,6 +709,21 @@ class PropagatorInterface {
|
||||
}
|
||||
};
|
||||
|
||||
// Singleton for basic reversible types. We need the wrapper so that they can be
|
||||
// accessed with model->GetOrCreate<>() and properly registered at creation.
|
||||
class RevIntRepository : public RevRepository<int> {
|
||||
public:
|
||||
explicit RevIntRepository(Model* model) {
|
||||
model->GetOrCreate<IntegerTrail>()->RegisterReversibleClass(this);
|
||||
}
|
||||
};
|
||||
class RevIntegerValueRepository : public RevRepository<IntegerValue> {
|
||||
public:
|
||||
explicit RevIntegerValueRepository(Model* model) {
|
||||
model->GetOrCreate<IntegerTrail>()->RegisterReversibleClass(this);
|
||||
}
|
||||
};
|
||||
|
||||
// This class allows registering Propagator that will be called if a
|
||||
// watched Literal or LbVar changes.
|
||||
//
|
||||
@@ -779,7 +801,7 @@ class GenericLiteralWatcher : public SatPropagator {
|
||||
void UpdateCallingNeeds(Trail* trail);
|
||||
|
||||
IntegerTrail* integer_trail_;
|
||||
RevRepository<int>* rev_int_repository_;
|
||||
RevIntRepository* rev_int_repository_;
|
||||
|
||||
struct WatchData {
|
||||
int id;
|
||||
|
||||
@@ -23,12 +23,13 @@ namespace sat {
|
||||
IntegerSumLE::IntegerSumLE(LiteralIndex reified_literal,
|
||||
const std::vector<IntegerVariable>& vars,
|
||||
const std::vector<IntegerValue>& coeffs,
|
||||
IntegerValue upper, Trail* trail,
|
||||
IntegerTrail* integer_trail)
|
||||
IntegerValue upper, Model* model)
|
||||
: reified_literal_(reified_literal),
|
||||
upper_bound_(upper),
|
||||
trail_(trail),
|
||||
integer_trail_(integer_trail),
|
||||
trail_(model->GetOrCreate<Trail>()),
|
||||
integer_trail_(model->GetOrCreate<IntegerTrail>()),
|
||||
rev_integer_value_repository_(
|
||||
model->GetOrCreate<RevIntegerValueRepository>()),
|
||||
vars_(vars),
|
||||
coeffs_(coeffs) {
|
||||
// TODO(user): deal with this corner case.
|
||||
@@ -76,7 +77,7 @@ bool IntegerSumLE::Propagate() {
|
||||
}
|
||||
|
||||
// Save the current number of fixed variables.
|
||||
rev_repository_integer_value_.SaveState(&rev_lb_fixed_vars_);
|
||||
rev_integer_value_repository_->SaveState(&rev_lb_fixed_vars_);
|
||||
|
||||
// Compute the new lower bound and update the reversible structures.
|
||||
IntegerValue lb_unfixed_vars = IntegerValue(0);
|
||||
@@ -192,7 +193,6 @@ void IntegerSumLE::RegisterWith(GenericLiteralWatcher* watcher) {
|
||||
watcher->WatchLiteral(Literal(reified_literal_), id);
|
||||
}
|
||||
watcher->RegisterReversibleInt(id, &rev_num_fixed_vars_);
|
||||
watcher->RegisterReversibleClass(id, &rev_repository_integer_value_);
|
||||
}
|
||||
|
||||
MinPropagator::MinPropagator(const std::vector<IntegerVariable>& vars,
|
||||
|
||||
@@ -44,8 +44,7 @@ class IntegerSumLE : public PropagatorInterface {
|
||||
IntegerSumLE(LiteralIndex reified_literal,
|
||||
const std::vector<IntegerVariable>& vars,
|
||||
const std::vector<IntegerValue>& coefficients,
|
||||
IntegerValue upper_bound, Trail* trail,
|
||||
IntegerTrail* integer_trail);
|
||||
IntegerValue upper_bound, Model* model);
|
||||
|
||||
// We propagate:
|
||||
// - If the sum of the individual lower-bound is > upper_bound, we fail.
|
||||
@@ -66,8 +65,7 @@ class IntegerSumLE : public PropagatorInterface {
|
||||
|
||||
Trail* trail_;
|
||||
IntegerTrail* integer_trail_;
|
||||
|
||||
RevRepository<IntegerValue> rev_repository_integer_value_;
|
||||
RevIntegerValueRepository* rev_integer_value_repository_;
|
||||
|
||||
// Reversible sum of the lower bound of the fixed variables.
|
||||
IntegerValue rev_lb_fixed_vars_;
|
||||
@@ -213,8 +211,7 @@ inline std::function<void(Model*)> WeightedSumLowerOrEqual(
|
||||
IntegerSumLE* constraint = new IntegerSumLE(
|
||||
kNoLiteralIndex, vars,
|
||||
std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
|
||||
IntegerValue(upper_bound), model->GetOrCreate<Trail>(),
|
||||
model->GetOrCreate<IntegerTrail>());
|
||||
IntegerValue(upper_bound), model);
|
||||
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
|
||||
model->TakeOwnership(constraint);
|
||||
};
|
||||
@@ -281,8 +278,7 @@ inline std::function<void(Model*)> ConditionalWeightedSumLowerOrEqual(
|
||||
IntegerSumLE* constraint = new IntegerSumLE(
|
||||
is_le.Index(), vars,
|
||||
std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
|
||||
IntegerValue(upper_bound), model->GetOrCreate<Trail>(),
|
||||
model->GetOrCreate<IntegerTrail>());
|
||||
IntegerValue(upper_bound), model);
|
||||
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
|
||||
model->TakeOwnership(constraint);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user