more work on sat solver: nlog(n) all different on bounds; improve difference logic propagator

This commit is contained in:
Laurent Perron
2016-10-13 10:44:35 -07:00
parent b067d36db8
commit 3bc1b082fd
10 changed files with 377 additions and 88 deletions

View File

@@ -1051,7 +1051,7 @@ void SolveWithSat(const fz::Model& fz_model, const fz::FlatzincParameters& p,
FZLOG << "Num integer variables = "
<< m.model.Get<IntegerTrail>()->NumIntegerVariables() / 2 << FZENDL;
FZLOG << "Num fully encoded variable = " << num_fully_encoded_variables / 2
<< FZENDL;
<< FZENDL;
FZLOG << "Num Boolean variables created = "
<< m.model.Get<SatSolver>()->NumVariables() << FZENDL;
FZLOG << "Num constants = " << m.constant_map.size() << FZENDL;

View File

@@ -72,6 +72,117 @@ void BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
}
}
AllDifferentBoundsPropagator::AllDifferentBoundsPropagator(
const std::vector<IntegerVariable>& vars, IntegerTrail* integer_trail)
: vars_(vars), integer_trail_(integer_trail), num_calls_(0) {
for (int i = 0; i < vars.size(); ++i) {
negated_vars_.push_back(NegationOf(vars_[i]));
}
}
bool AllDifferentBoundsPropagator::Propagate(Trail* trail) {
if (vars_.empty()) return true;
if (!PropagateLowerBounds(trail)) return false;
// Note that it is not required to swap back vars_ and negated_vars_.
// TODO(user): investigate the impact.
std::swap(vars_, negated_vars_);
const bool result = PropagateLowerBounds(trail);
std::swap(vars_, negated_vars_);
return result;
}
// TODO(user): we could gain by pushing all the new bound at the end, so that
// we just have to sort to_insert_ once.
void AllDifferentBoundsPropagator::FillHallReason(IntegerValue hall_lb,
IntegerValue hall_ub) {
for (auto entry : to_insert_) {
value_to_variable_[entry.first] = entry.second;
}
to_insert_.clear();
integer_reason_.clear();
for (int64 v = hall_lb.value(); v <= hall_ub; ++v) {
const IntegerVariable var = FindOrDie(value_to_variable_, v);
integer_reason_.push_back(IntegerLiteral::GreaterOrEqual(var, hall_lb));
integer_reason_.push_back(IntegerLiteral::LowerOrEqual(var, hall_ub));
}
}
bool AllDifferentBoundsPropagator::PropagateLowerBounds(Trail* trail) {
++num_calls_;
critical_intervals_.clear();
hall_starts_.clear();
hall_ends_.clear();
to_insert_.clear();
if (num_calls_ % 20 == 0) {
// We don't really need to clear this, but we do from time to time to
// save memory (in case the variable domains are huge). This optimization
// helps a bit.
value_to_variable_.clear();
}
// Loop over the variables by increasing ub.
std::sort(
vars_.begin(), vars_.end(), [this](IntegerVariable a, IntegerVariable b) {
return integer_trail_->UpperBound(a) < integer_trail_->UpperBound(b);
});
for (const IntegerVariable var : vars_) {
const IntegerValue lb = integer_trail_->LowerBound(var);
// Check if lb is in an Hall interval, and push it if this is the case.
const int hall_index =
std::lower_bound(hall_ends_.begin(), hall_ends_.end(), lb) -
hall_ends_.begin();
if (hall_index < hall_ends_.size() && hall_starts_[hall_index] <= lb) {
const IntegerValue hs = hall_starts_[hall_index];
const IntegerValue he = hall_ends_[hall_index];
FillHallReason(hs, he);
integer_reason_.push_back(IntegerLiteral::GreaterOrEqual(var, hs));
if (!integer_trail_->Enqueue(IntegerLiteral::GreaterOrEqual(var, he + 1),
/*literal_reason=*/{}, integer_reason_)) {
return false;
}
}
// Updates critical_intervals_. Note that we use the old lb, but that
// doesn't change the value of newly_covered. This block is what takes the
// most time.
int64 newly_covered;
const auto it =
critical_intervals_.GrowRightByOne(lb.value(), &newly_covered);
to_insert_.push_back({newly_covered, var});
const IntegerValue end(it->end);
// We cannot have a conflict, because it should have beend detected before
// by pushing an interval lower bound past its upper bound.
DCHECK_LE(end, integer_trail_->UpperBound(var));
// If we have a new Hall interval, add it to the set. Note that it will
// always be last, and if it overlaps some previous Hall intervals, it
// always overlaps them fully.
if (end == integer_trail_->UpperBound(var)) {
const IntegerValue start(it->start);
while (!hall_starts_.empty() && start <= hall_starts_.back()) {
hall_starts_.pop_back();
hall_ends_.pop_back();
}
DCHECK(hall_ends_.empty() || hall_ends_.back() < start);
hall_starts_.push_back(start);
hall_ends_.push_back(end);
}
}
return true;
}
void AllDifferentBoundsPropagator::RegisterWith(
GenericLiteralWatcher* watcher) {
const int id = watcher->Register(this);
for (const IntegerVariable& var : vars_) {
watcher->WatchIntegerVariable(var, id);
}
}
std::function<void(Model*)> AllDifferent(const std::vector<IntegerVariable>& vars) {
return [=](Model* model) {
hash_set<IntegerValue> fixed_values;

View File

@@ -14,8 +14,11 @@
#ifndef OR_TOOLS_SAT_CP_CONSTRAINTS_H_
#define OR_TOOLS_SAT_CP_CONSTRAINTS_H_
#include <unordered_map>
#include "sat/integer.h"
#include "sat/model.h"
#include "util/sorted_interval_list.h"
namespace operations_research {
namespace sat {
@@ -42,13 +45,70 @@ class BooleanXorPropagator : public PropagatorInterface {
DISALLOW_COPY_AND_ASSIGN(BooleanXorPropagator);
};
// Implement the all different bound consistent propagator with explanation.
// That is, given n variables that must be all different, this propagates the
// bounds of each variables as much as possible. The key is to detect the so
// called Hall interval which are interval of size k that contains the domain
// of k variables. Because all the variables must take different values, we can
// deduce that the domain of the other variables cannot contains such Hall
// interval.
//
// We use a "simple" O(n log n) algorithm.
//
// TODO(user): implement the faster algorithm described in:
// https://cs.uwaterloo.ca/~vanbeek/Publications/ijcai03_TR.pdf
// Note that the algorithms are similar, the gain comes by replacing our
// SortedDisjointIntervalList with a more customized class for our operations.
// It is even possible to get an O(n) complexity if the values of the bounds are
// in a range of size O(n).
class AllDifferentBoundsPropagator : public PropagatorInterface {
public:
AllDifferentBoundsPropagator(const std::vector<IntegerVariable>& vars,
IntegerTrail* integer_trail);
bool Propagate(Trail* trail) final;
void RegisterWith(GenericLiteralWatcher* watcher);
private:
// Fills integer_reason_ with the reason why we have the given hall interval.
void FillHallReason(IntegerValue hall_lb, IntegerValue hall_ub);
// Do half the job of Propagate().
bool PropagateLowerBounds(Trail* trail);
std::vector<IntegerVariable> vars_;
std::vector<IntegerVariable> negated_vars_;
IntegerTrail* integer_trail_;
// The sets of "critical" intervals. This has the same meaning as in the
// disjunctive constraint.
SortedDisjointIntervalList critical_intervals_;
// The list of Hall intervalls detected so far, sorted.
std::vector<IntegerValue> hall_starts_;
std::vector<IntegerValue> hall_ends_;
// Members needed for explaining the propagation.
//
// The IntegerVariable in an hall interval [lb, ub] are the variables with key
// in [lb, ub] in this map. Note(user): if the set of bounds is small, we
// could use a vector here. The O(ub - lb) to create the reason is fine since
// this is the size of the reason.
//
// Optimization: we only insert the entry in the map lazily when the reason
// is needed.
int64 num_calls_;
std::vector<std::pair<int64, IntegerVariable>> to_insert_;
std::unordered_map<int64, IntegerVariable> value_to_variable_;
std::vector<IntegerLiteral> integer_reason_;
DISALLOW_COPY_AND_ASSIGN(AllDifferentBoundsPropagator);
};
// ============================================================================
// Model based functions.
// ============================================================================
// Enforces that the given tuple of variables takes different values.
std::function<void(Model*)> AllDifferent(const std::vector<IntegerVariable>& vars);
// Enforces the XOR of a set of literals to be equal to the given value.
inline std::function<void(Model*)> LiteralXorIs(const std::vector<Literal>& literals,
bool value) {
@@ -61,6 +121,28 @@ inline std::function<void(Model*)> LiteralXorIs(const std::vector<Literal>& lite
};
}
// Enforces that the given tuple of variables takes different values.
std::function<void(Model*)> AllDifferent(const std::vector<IntegerVariable>& vars);
// Enforces that the given tuple of variables takes different values.
// Same as AllDifferent() but use a different propagator that only enforce
// the so called "bound consistency" on the variable domains.
//
// Compared to AllDifferent() this doesn't require fully encoding the variables
// and it is also quite fast. Note that the propagation is different, this will
// not remove already taken values from inside a domain, but it will propagates
// more the domain bounds.
inline std::function<void(Model*)> AllDifferentOnBounds(
const std::vector<IntegerVariable>& vars) {
return [=](Model* model) {
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
AllDifferentBoundsPropagator* constraint =
new AllDifferentBoundsPropagator(vars, integer_trail);
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
};
}
} // namespace sat
} // namespace operations_research

View File

@@ -126,7 +126,7 @@ void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit,
// Associate the new literal to i_lit.
AddImplications(i_lit, literal);
reverse_encoding_[literal.Index()] = i_lit;
reverse_encoding_[literal.Index()].push_back(i_lit);
// Add its negation and associated it with i_lit.Negated().
//
@@ -134,7 +134,7 @@ void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit,
// 100% sure why!! I think it works because these literals can only appear
// in a conflict if the presence literal of the optional variables is true.
AddImplications(i_lit.Negated(), literal.Negated());
reverse_encoding_[literal.NegatedIndex()] = i_lit.Negated();
reverse_encoding_[literal.NegatedIndex()].push_back(i_lit.Negated());
}
Literal IntegerEncoder::CreateAssociatedLiteral(IntegerLiteral i_lit) {
@@ -232,8 +232,7 @@ bool IntegerTrail::Propagate(Trail* trail) {
const Literal literal = (*trail)[propagation_trail_index_++];
// Bound encoder.
const IntegerLiteral i_lit = encoder_->GetIntegerLiteral(literal);
if (i_lit.var >= 0) {
for (const IntegerLiteral i_lit : encoder_->GetIntegerLiterals(literal)) {
// The reason is simply the associated literal.
if (!Enqueue(i_lit, {literal.Negated()}, {})) return false;
}

View File

@@ -141,6 +141,8 @@ inline std::ostream& operator<<(std::ostream& os, IntegerLiteral i_lit) {
return os;
}
using InlinedIntegerLiteralVector = std::vector<IntegerLiteral>;
// Each integer variable x will be associated with a set of literals encoding
// (x >= v) for some values of v. This class maintains the relationship between
// the integer variables and such literals which can be created by a call to
@@ -260,11 +262,11 @@ class IntegerEncoder {
// Same as CreateAssociatedLiteral() but safe to call if already created.
Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit);
// Returns the IntegerLiteral that was associated with the given Boolean
// literal or an IntegerLiteral with a variable set to kNoIntegerVariable if
// the argument does not correspond to such literal.
IntegerLiteral GetIntegerLiteral(Literal lit) const {
if (lit.Index() >= reverse_encoding_.size()) return IntegerLiteral();
// Returns the IntegerLiterals that were associated with the given Literal.
const InlinedIntegerLiteralVector& GetIntegerLiterals(Literal lit) const {
if (lit.Index() >= reverse_encoding_.size()) {
return empty_integer_literal_vector_;
}
return reverse_encoding_[lit.Index()];
}
@@ -291,10 +293,9 @@ class IntegerEncoder {
// corresponding to the same variable).
ITIVector<IntegerVariable, std::map<IntegerValue, Literal>> encoding_by_var_;
// Store for a given LiteralIndex its associated IntegerLiteral or an
// IntegerLiteral with kNoIntegerVariable as a variable if the LiteralIndex
// doesn't correspond to an IntegerLiteral.
ITIVector<LiteralIndex, IntegerLiteral> reverse_encoding_;
// Store for a given LiteralIndex the list of its associated IntegerLiterals.
const InlinedIntegerLiteralVector empty_integer_literal_vector_;
ITIVector<LiteralIndex, InlinedIntegerLiteralVector> reverse_encoding_;
// Full domain encoding. The map contains the index in full_encoding_ of
// the fully encoded variable. Each entry in full_encoding_ is sorted by
@@ -744,12 +745,7 @@ inline std::function<void(Model*)> Equality(IntegerVariable v, int64 value) {
inline std::function<void(Model*)> Equality(IntegerLiteral i, Literal l) {
return [=](Model* model) {
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
// Tricky: currently we cannot associate the same literal to two different
// IntegerLiteral! The second test verifies that l is not already
// associated.
if (encoder->LiteralIsAssociated(i) ||
encoder->GetIntegerLiteral(l) != IntegerLiteral()) {
if (encoder->LiteralIsAssociated(i)) {
const Literal current = encoder->GetOrCreateAssociatedLiteral(i);
model->Add(Equality(current, l));
} else {

View File

@@ -199,7 +199,6 @@ inline std::function<void(Model*)> WeightedSumLowerOrEqual(
const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
int64 upper_bound) {
// Special cases.
// TODO(user): Do the same for the reified case.
CHECK_GE(vars.size(), 1) << "Should be encoded differently.";
if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) &&
(coefficients[1] == 1 || coefficients[1] == -1)) {
@@ -230,7 +229,7 @@ template <typename VectorInt>
inline std::function<void(Model*)> WeightedSumGreaterOrEqual(
const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
int64 lower_bound) {
// We just negate everything and use an IntegerSumLE() constraints.
// We just negate everything and use an <= constraints.
std::vector<IntegerValue> negated_coeffs(coefficients.begin(), coefficients.end());
for (IntegerValue& ref : negated_coeffs) ref = -ref;
return WeightedSumLowerOrEqual(vars, negated_coeffs, -lower_bound);
@@ -247,33 +246,61 @@ inline std::function<void(Model*)> FixedWeightedSum(
};
}
// is_le => sum <= upper_bound
template <typename VectorInt>
inline std::function<void(Model*)> ConditionalWeightedSumLowerOrEqual(
Literal is_le, const std::vector<IntegerVariable>& vars,
const VectorInt& coefficients, int64 upper_bound) {
// Special cases.
CHECK_GE(vars.size(), 1) << "Should be encoded differently.";
if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) &&
(coefficients[1] == 1 || coefficients[1] == -1)) {
return ConditionalSum2LowerOrEqual(
coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]), upper_bound,
is_le);
}
if (vars.size() == 3 && (coefficients[0] == 1 || coefficients[0] == -1) &&
(coefficients[1] == 1 || coefficients[1] == -1) &&
(coefficients[2] == 1 || coefficients[2] == -1)) {
return ConditionalSum3LowerOrEqual(
coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]),
coefficients[2] == 1 ? vars[2] : NegationOf(vars[2]), upper_bound,
is_le);
}
return [=](Model* model) {
IntegerSumLE* constraint = new IntegerSumLE(
is_le.Index(), vars,
std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
IntegerValue(upper_bound), model->GetOrCreate<IntegerTrail>());
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
};
}
// is_ge => sum >= lower_bound
template <typename VectorInt>
inline std::function<void(Model*)> ConditionalWeightedSumGreaterOrEqual(
Literal is_ge, const std::vector<IntegerVariable>& vars,
const VectorInt& coefficients, int64 lower_bound) {
// We just negate everything and use an <= constraint.
std::vector<IntegerValue> negated_coeffs(coefficients.begin(), coefficients.end());
for (IntegerValue& ref : negated_coeffs) ref = -ref;
return ConditionalWeightedSumLowerOrEqual(is_ge, vars, negated_coeffs,
-lower_bound);
}
// Weighted sum <= constant reified.
template <typename VectorInt>
inline std::function<void(Model*)> WeightedSumLowerOrEqualReif(
Literal is_le, const std::vector<IntegerVariable>& vars,
const VectorInt& coefficients, int64 upper_bound) {
return [=](Model* model) {
// is_le => lin <= upper_bound
{
IntegerSumLE* constraint = new IntegerSumLE(
is_le.Index(), vars,
std::vector<IntegerValue>(coefficients.begin(), coefficients.end()),
IntegerValue(upper_bound), model->GetOrCreate<IntegerTrail>());
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
}
// not(is_le) => lin > upper_bound, i.e -lin <= -upper_bound - 1
{
std::vector<IntegerValue> negated_coeffs(coefficients.begin(),
coefficients.end());
for (IntegerValue& ref : negated_coeffs) ref = -ref;
IntegerSumLE* constraint = new IntegerSumLE(
is_le.NegatedIndex(), vars, negated_coeffs,
IntegerValue(-upper_bound - 1), model->GetOrCreate<IntegerTrail>());
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
}
model->Add(ConditionalWeightedSumLowerOrEqual(is_le, vars, coefficients,
upper_bound));
model->Add(ConditionalWeightedSumGreaterOrEqual(
is_le.Negated(), vars, coefficients, upper_bound + 1));
};
}
@@ -282,8 +309,12 @@ template <typename VectorInt>
inline std::function<void(Model*)> WeightedSumGreaterOrEqualReif(
Literal is_ge, const std::vector<IntegerVariable>& vars,
const VectorInt& coefficients, int64 lower_bound) {
return WeightedSumLowerOrEqualReif(is_ge.Negated(), vars, coefficients,
lower_bound - 1);
return [=](Model* model) {
model->Add(ConditionalWeightedSumGreaterOrEqual(is_ge, vars, coefficients,
lower_bound));
model->Add(ConditionalWeightedSumLowerOrEqual(
is_ge.Negated(), vars, coefficients, lower_bound - 1));
};
}
// Weighted sum == constant reified.
@@ -308,14 +339,13 @@ inline std::function<void(Model*)> WeightedSumNotEqual(
const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
int64 value) {
return [=](Model* model) {
// We creates two extra Boolean variables in this case.
// Exactly one of these alternative must be true.
const Literal is_lt = Literal(model->Add(NewBooleanVariable()), true);
const Literal is_gt = Literal(model->Add(NewBooleanVariable()), true);
model->Add(ClauseConstraint({is_lt, is_gt}));
model->Add(
WeightedSumLowerOrEqualReif(is_lt, vars, coefficients, value - 1));
model->Add(
WeightedSumGreaterOrEqualReif(is_gt, vars, coefficients, value + 1));
const Literal is_gt = is_lt.Negated();
model->Add(ConditionalWeightedSumLowerOrEqual(is_lt, vars, coefficients,
value - 1));
model->Add(ConditionalWeightedSumGreaterOrEqual(is_gt, vars, coefficients,
value + 1));
};
}

View File

@@ -211,22 +211,6 @@ void PrecedencesPropagator::MarkIntegerVariableAsOptional(IntegerVariable i,
void PrecedencesPropagator::AddArc(IntegerVariable tail, IntegerVariable head,
IntegerValue offset,
IntegerVariable offset_var, LiteralIndex l) {
if (head == tail) {
// A self-arc is either plain SAT or plan UNSAT or it forces something on
// the given offset_var or l. In any case it could be presolved in something
// more efficent.
LOG(WARNING) << "Self arc! This could be presolved. "
<< "var:" << tail << " offset:" << offset
<< " offset_var:" << offset_var << " conditioned_by:" << l;
if (offset <= 0 && offset_var == kNoIntegerVariable &&
l == kNoLiteralIndex) {
return; // no-op.
}
}
AdjustSizeFor(tail);
AdjustSizeFor(head);
if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var);
// Handle level zero stuff.
DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
if (l != kNoLiteralIndex) {
@@ -238,6 +222,24 @@ void PrecedencesPropagator::AddArc(IntegerVariable tail, IntegerVariable head,
}
}
if (head == tail) {
// A self-arc is either plain SAT or plan UNSAT or it forces something on
// the given offset_var or l. In any case it could be presolved in something
// more efficent.
LOG(WARNING) << "Self arc! This could be presolved. "
<< "var:" << tail << " offset:" << offset
<< " offset_var:" << offset_var << " conditioned_by:" << l;
if (offset_var == kNoIntegerVariable) {
// Always false => l is false, otherwise this is a no op.
if (offset > 0) trail_->EnqueueWithUnitReason(Literal(l).Negated());
return;
}
}
AdjustSizeFor(tail);
AdjustSizeFor(head);
if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var);
if (l != kNoLiteralIndex && l.value() >= potential_arcs_.size()) {
potential_arcs_.resize(l.value() + 1);
}

View File

@@ -87,10 +87,11 @@ class PrecedencesPropagator : public Propagator {
// when I wrote this, I just had a couple of problems to test this on.
void AddPrecedenceWithVariableOffset(IntegerVariable i1, IntegerVariable i2,
IntegerVariable offset_var);
void AddPrecedenceWithVariableAndFixedOffset(IntegerVariable i1,
IntegerVariable i2,
IntegerValue offset,
IntegerVariable offset_var);
// Generic function that cover all of the above case and more.
void AddPrecedenceWithAllOptions(IntegerVariable i1, IntegerVariable i2,
IntegerValue offset,
IntegerVariable offset_var, LiteralIndex l);
// An optional integer variable has a special behavior:
// - If the bounds on i cross each other, then is_present must be false.
@@ -312,10 +313,10 @@ inline void PrecedencesPropagator::AddPrecedenceWithVariableOffset(
AddArc(i1, i2, /*offset=*/IntegerValue(0), offset_var, /*l=*/kNoLiteralIndex);
}
inline void PrecedencesPropagator::AddPrecedenceWithVariableAndFixedOffset(
inline void PrecedencesPropagator::AddPrecedenceWithAllOptions(
IntegerVariable i1, IntegerVariable i2, IntegerValue offset,
IntegerVariable offset_var) {
AddArc(i1, i2, offset, offset_var, /*l=*/kNoLiteralIndex);
IntegerVariable offset_var, LiteralIndex r) {
AddArc(i1, i2, offset, offset_var, r);
}
// =============================================================================
@@ -347,15 +348,36 @@ inline std::function<void(Model*)> Sum2LowerOrEqual(IntegerVariable a,
return LowerOrEqualWithOffset(a, NegationOf(b), -ub);
}
// l => (a + b <= ub).
inline std::function<void(Model*)> ConditionalSum2LowerOrEqual(
IntegerVariable a, IntegerVariable b, int64 ub, Literal l) {
return [=](Model* model) {
PrecedencesPropagator* p = model->GetOrCreate<PrecedencesPropagator>();
p->AddPrecedenceWithAllOptions(a, NegationOf(b), IntegerValue(-ub),
kNoIntegerVariable, l.Index());
};
}
// a + b + c <= ub.
inline std::function<void(Model*)> Sum3LowerOrEqual(IntegerVariable a,
IntegerVariable b,
IntegerVariable c,
int64 ub) {
return [=](Model* model) {
return model->GetOrCreate<PrecedencesPropagator>()
->AddPrecedenceWithVariableAndFixedOffset(a, NegationOf(c),
IntegerValue(-ub), b);
PrecedencesPropagator* p = model->GetOrCreate<PrecedencesPropagator>();
p->AddPrecedenceWithAllOptions(a, NegationOf(c), IntegerValue(-ub), b,
kNoLiteralIndex);
};
}
// l => (a + b + c <= ub).
inline std::function<void(Model*)> ConditionalSum3LowerOrEqual(
IntegerVariable a, IntegerVariable b, IntegerVariable c, int64 ub,
Literal l) {
return [=](Model* model) {
PrecedencesPropagator* p = model->GetOrCreate<PrecedencesPropagator>();
p->AddPrecedenceWithAllOptions(a, NegationOf(c), IntegerValue(-ub), b,
l.Index());
};
}
@@ -399,7 +421,7 @@ inline std::function<void(Model*)> ReifiedLowerOrEqualWithOffset(
};
}
// is_eq <=> (a + offset == b).
// is_eq <=> (a == b).
inline std::function<void(Model*)> ReifiedEquality(IntegerVariable a,
IntegerVariable b,
Literal is_eq) {
@@ -421,12 +443,11 @@ inline std::function<void(Model*)> ReifiedEquality(IntegerVariable a,
inline std::function<void(Model*)> NotEqual(IntegerVariable a,
IntegerVariable b) {
return [=](Model* model) {
// We model this by is_le and is_ge cannot be both true.
const Literal is_le = Literal(model->Add(NewBooleanVariable()), true);
const Literal is_ge = Literal(model->Add(NewBooleanVariable()), true);
model->Add(Implication(is_le, is_ge.Negated()));
model->Add(ReifiedLowerOrEqualWithOffset(a, b, 0, is_le));
model->Add(ReifiedLowerOrEqualWithOffset(b, a, 0, is_ge));
// We have two options (is_gt or is_lt) and one must be true.
const Literal is_lt = Literal(model->Add(NewBooleanVariable()), true);
const Literal is_gt = is_lt.Negated();
model->Add(ConditionalLowerOrEqualWithOffset(a, b, 1, is_lt));
model->Add(ConditionalLowerOrEqualWithOffset(b, a, 1, is_gt));
};
}

View File

@@ -117,6 +117,45 @@ SortedDisjointIntervalList::Iterator SortedDisjointIntervalList::InsertInterval(
return it;
}
SortedDisjointIntervalList::Iterator SortedDisjointIntervalList::GrowRightByOne(
int64 value, int64* newly_covered) {
auto it = intervals_.upper_bound({value, kint64max});
auto it_prev = it;
// No interval containing or adjacent to "value" on the left (i.e. below).
if (it == begin() || ((--it_prev)->end < value - 1 && value != kint64min)) {
*newly_covered = value;
if (it == end() || it->start != value + 1) {
// No interval adjacent to "value" on the right: insert a singleton.
return intervals_.insert(it, {value, value});
} else {
// There is an interval adjacent to "value" on the right. Extend it by
// one. Note that we already know that there won't be a merge with another
// interval on the left, since there were no interval adjacent to "value"
// on the left.
DCHECK_EQ(it->start, value + 1);
const_cast<Interval*>(&(*it))->start = value;
return it;
}
}
// At this point, "it_prev" points to an interval containing or adjacent to
// "value" on the left: grow it by one, and if it now touches the next
// interval, merge with it.
CHECK_NE(kint64max, it_prev->end) << "Cannot grow right by one: the interval "
"that would grow already ends at "
"kint64max";
*newly_covered = it_prev->end + 1;
if (it != end() && it_prev->end + 2 == it->start) {
// We need to merge it_prev with 'it'.
const_cast<Interval*>(&(*it_prev))->end = it->end;
intervals_.erase(it);
} else {
const_cast<Interval*>(&(*it_prev))->end = it_prev->end + 1;
}
return it_prev;
}
template <class T>
void SortedDisjointIntervalList::InsertAll(const std::vector<T>& starts,
const std::vector<T>& ends) {

View File

@@ -72,6 +72,15 @@ class SortedDisjointIntervalList {
// If start > end, it does LOG(DFATAL) and returns end() (no interval added).
Iterator InsertInterval(int64 start, int64 end);
// If value is in an interval, increase its end by one, otherwise insert the
// interval [value, value]. In both cases, this returns an iterator to the
// new/modified interval (possibly merged with others) and fills newly_covered
// with the new value that was just added in the union of all the intervals.
//
// If this causes an interval ending at kint64max to grow, it will die with a
// CHECK fail.
Iterator GrowRightByOne(int64 value, int64* newly_covered);
// Adds all intervals [starts[i]..ends[i]]. Same behavior as InsertInterval()
// upon invalid intervals. There's a version with int64 and int32.
void InsertIntervals(const std::vector<int64>& starts, const std::vector<int64>& ends);