diff --git a/src/flatzinc/sat_fz_solver.cc b/src/flatzinc/sat_fz_solver.cc index 658f7e87a5..457674aeb8 100644 --- a/src/flatzinc/sat_fz_solver.cc +++ b/src/flatzinc/sat_fz_solver.cc @@ -1051,7 +1051,7 @@ void SolveWithSat(const fz::Model& fz_model, const fz::FlatzincParameters& p, FZLOG << "Num integer variables = " << m.model.Get()->NumIntegerVariables() / 2 << FZENDL; FZLOG << "Num fully encoded variable = " << num_fully_encoded_variables / 2 - << FZENDL; + << FZENDL; FZLOG << "Num Boolean variables created = " << m.model.Get()->NumVariables() << FZENDL; FZLOG << "Num constants = " << m.constant_map.size() << FZENDL; diff --git a/src/sat/cp_constraints.cc b/src/sat/cp_constraints.cc index 313bec4576..9786a9509b 100644 --- a/src/sat/cp_constraints.cc +++ b/src/sat/cp_constraints.cc @@ -72,6 +72,117 @@ void BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) { } } +AllDifferentBoundsPropagator::AllDifferentBoundsPropagator( + const std::vector& vars, IntegerTrail* integer_trail) + : vars_(vars), integer_trail_(integer_trail), num_calls_(0) { + for (int i = 0; i < vars.size(); ++i) { + negated_vars_.push_back(NegationOf(vars_[i])); + } +} + +bool AllDifferentBoundsPropagator::Propagate(Trail* trail) { + if (vars_.empty()) return true; + if (!PropagateLowerBounds(trail)) return false; + + // Note that it is not required to swap back vars_ and negated_vars_. + // TODO(user): investigate the impact. + std::swap(vars_, negated_vars_); + const bool result = PropagateLowerBounds(trail); + std::swap(vars_, negated_vars_); + return result; +} + +// TODO(user): we could gain by pushing all the new bound at the end, so that +// we just have to sort to_insert_ once. +void AllDifferentBoundsPropagator::FillHallReason(IntegerValue hall_lb, + IntegerValue hall_ub) { + for (auto entry : to_insert_) { + value_to_variable_[entry.first] = entry.second; + } + to_insert_.clear(); + integer_reason_.clear(); + for (int64 v = hall_lb.value(); v <= hall_ub; ++v) { + const IntegerVariable var = FindOrDie(value_to_variable_, v); + integer_reason_.push_back(IntegerLiteral::GreaterOrEqual(var, hall_lb)); + integer_reason_.push_back(IntegerLiteral::LowerOrEqual(var, hall_ub)); + } +} + +bool AllDifferentBoundsPropagator::PropagateLowerBounds(Trail* trail) { + ++num_calls_; + critical_intervals_.clear(); + hall_starts_.clear(); + hall_ends_.clear(); + + to_insert_.clear(); + if (num_calls_ % 20 == 0) { + // We don't really need to clear this, but we do from time to time to + // save memory (in case the variable domains are huge). This optimization + // helps a bit. + value_to_variable_.clear(); + } + + // Loop over the variables by increasing ub. + std::sort( + vars_.begin(), vars_.end(), [this](IntegerVariable a, IntegerVariable b) { + return integer_trail_->UpperBound(a) < integer_trail_->UpperBound(b); + }); + for (const IntegerVariable var : vars_) { + const IntegerValue lb = integer_trail_->LowerBound(var); + + // Check if lb is in an Hall interval, and push it if this is the case. + const int hall_index = + std::lower_bound(hall_ends_.begin(), hall_ends_.end(), lb) - + hall_ends_.begin(); + if (hall_index < hall_ends_.size() && hall_starts_[hall_index] <= lb) { + const IntegerValue hs = hall_starts_[hall_index]; + const IntegerValue he = hall_ends_[hall_index]; + FillHallReason(hs, he); + integer_reason_.push_back(IntegerLiteral::GreaterOrEqual(var, hs)); + if (!integer_trail_->Enqueue(IntegerLiteral::GreaterOrEqual(var, he + 1), + /*literal_reason=*/{}, integer_reason_)) { + return false; + } + } + + // Updates critical_intervals_. Note that we use the old lb, but that + // doesn't change the value of newly_covered. This block is what takes the + // most time. + int64 newly_covered; + const auto it = + critical_intervals_.GrowRightByOne(lb.value(), &newly_covered); + to_insert_.push_back({newly_covered, var}); + const IntegerValue end(it->end); + + // We cannot have a conflict, because it should have beend detected before + // by pushing an interval lower bound past its upper bound. + DCHECK_LE(end, integer_trail_->UpperBound(var)); + + // If we have a new Hall interval, add it to the set. Note that it will + // always be last, and if it overlaps some previous Hall intervals, it + // always overlaps them fully. + if (end == integer_trail_->UpperBound(var)) { + const IntegerValue start(it->start); + while (!hall_starts_.empty() && start <= hall_starts_.back()) { + hall_starts_.pop_back(); + hall_ends_.pop_back(); + } + DCHECK(hall_ends_.empty() || hall_ends_.back() < start); + hall_starts_.push_back(start); + hall_ends_.push_back(end); + } + } + return true; +} + +void AllDifferentBoundsPropagator::RegisterWith( + GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + for (const IntegerVariable& var : vars_) { + watcher->WatchIntegerVariable(var, id); + } +} + std::function AllDifferent(const std::vector& vars) { return [=](Model* model) { hash_set fixed_values; diff --git a/src/sat/cp_constraints.h b/src/sat/cp_constraints.h index 9408357817..1070e3123f 100644 --- a/src/sat/cp_constraints.h +++ b/src/sat/cp_constraints.h @@ -14,8 +14,11 @@ #ifndef OR_TOOLS_SAT_CP_CONSTRAINTS_H_ #define OR_TOOLS_SAT_CP_CONSTRAINTS_H_ +#include + #include "sat/integer.h" #include "sat/model.h" +#include "util/sorted_interval_list.h" namespace operations_research { namespace sat { @@ -42,13 +45,70 @@ class BooleanXorPropagator : public PropagatorInterface { DISALLOW_COPY_AND_ASSIGN(BooleanXorPropagator); }; +// Implement the all different bound consistent propagator with explanation. +// That is, given n variables that must be all different, this propagates the +// bounds of each variables as much as possible. The key is to detect the so +// called Hall interval which are interval of size k that contains the domain +// of k variables. Because all the variables must take different values, we can +// deduce that the domain of the other variables cannot contains such Hall +// interval. +// +// We use a "simple" O(n log n) algorithm. +// +// TODO(user): implement the faster algorithm described in: +// https://cs.uwaterloo.ca/~vanbeek/Publications/ijcai03_TR.pdf +// Note that the algorithms are similar, the gain comes by replacing our +// SortedDisjointIntervalList with a more customized class for our operations. +// It is even possible to get an O(n) complexity if the values of the bounds are +// in a range of size O(n). +class AllDifferentBoundsPropagator : public PropagatorInterface { + public: + AllDifferentBoundsPropagator(const std::vector& vars, + IntegerTrail* integer_trail); + + bool Propagate(Trail* trail) final; + void RegisterWith(GenericLiteralWatcher* watcher); + + private: + // Fills integer_reason_ with the reason why we have the given hall interval. + void FillHallReason(IntegerValue hall_lb, IntegerValue hall_ub); + + // Do half the job of Propagate(). + bool PropagateLowerBounds(Trail* trail); + + std::vector vars_; + std::vector negated_vars_; + IntegerTrail* integer_trail_; + + // The sets of "critical" intervals. This has the same meaning as in the + // disjunctive constraint. + SortedDisjointIntervalList critical_intervals_; + + // The list of Hall intervalls detected so far, sorted. + std::vector hall_starts_; + std::vector hall_ends_; + + // Members needed for explaining the propagation. + // + // The IntegerVariable in an hall interval [lb, ub] are the variables with key + // in [lb, ub] in this map. Note(user): if the set of bounds is small, we + // could use a vector here. The O(ub - lb) to create the reason is fine since + // this is the size of the reason. + // + // Optimization: we only insert the entry in the map lazily when the reason + // is needed. + int64 num_calls_; + std::vector> to_insert_; + std::unordered_map value_to_variable_; + std::vector integer_reason_; + + DISALLOW_COPY_AND_ASSIGN(AllDifferentBoundsPropagator); +}; + // ============================================================================ // Model based functions. // ============================================================================ -// Enforces that the given tuple of variables takes different values. -std::function AllDifferent(const std::vector& vars); - // Enforces the XOR of a set of literals to be equal to the given value. inline std::function LiteralXorIs(const std::vector& literals, bool value) { @@ -61,6 +121,28 @@ inline std::function LiteralXorIs(const std::vector& lite }; } +// Enforces that the given tuple of variables takes different values. +std::function AllDifferent(const std::vector& vars); + +// Enforces that the given tuple of variables takes different values. +// Same as AllDifferent() but use a different propagator that only enforce +// the so called "bound consistency" on the variable domains. +// +// Compared to AllDifferent() this doesn't require fully encoding the variables +// and it is also quite fast. Note that the propagation is different, this will +// not remove already taken values from inside a domain, but it will propagates +// more the domain bounds. +inline std::function AllDifferentOnBounds( + const std::vector& vars) { + return [=](Model* model) { + IntegerTrail* integer_trail = model->GetOrCreate(); + AllDifferentBoundsPropagator* constraint = + new AllDifferentBoundsPropagator(vars, integer_trail); + constraint->RegisterWith(model->GetOrCreate()); + model->TakeOwnership(constraint); + }; +} + } // namespace sat } // namespace operations_research diff --git a/src/sat/integer.cc b/src/sat/integer.cc index 6c3fad42b1..1159b13bf6 100644 --- a/src/sat/integer.cc +++ b/src/sat/integer.cc @@ -126,7 +126,7 @@ void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit, // Associate the new literal to i_lit. AddImplications(i_lit, literal); - reverse_encoding_[literal.Index()] = i_lit; + reverse_encoding_[literal.Index()].push_back(i_lit); // Add its negation and associated it with i_lit.Negated(). // @@ -134,7 +134,7 @@ void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit, // 100% sure why!! I think it works because these literals can only appear // in a conflict if the presence literal of the optional variables is true. AddImplications(i_lit.Negated(), literal.Negated()); - reverse_encoding_[literal.NegatedIndex()] = i_lit.Negated(); + reverse_encoding_[literal.NegatedIndex()].push_back(i_lit.Negated()); } Literal IntegerEncoder::CreateAssociatedLiteral(IntegerLiteral i_lit) { @@ -232,8 +232,7 @@ bool IntegerTrail::Propagate(Trail* trail) { const Literal literal = (*trail)[propagation_trail_index_++]; // Bound encoder. - const IntegerLiteral i_lit = encoder_->GetIntegerLiteral(literal); - if (i_lit.var >= 0) { + for (const IntegerLiteral i_lit : encoder_->GetIntegerLiterals(literal)) { // The reason is simply the associated literal. if (!Enqueue(i_lit, {literal.Negated()}, {})) return false; } diff --git a/src/sat/integer.h b/src/sat/integer.h index ac5ad83575..b02dcc8c7b 100644 --- a/src/sat/integer.h +++ b/src/sat/integer.h @@ -141,6 +141,8 @@ inline std::ostream& operator<<(std::ostream& os, IntegerLiteral i_lit) { return os; } +using InlinedIntegerLiteralVector = std::vector; + // Each integer variable x will be associated with a set of literals encoding // (x >= v) for some values of v. This class maintains the relationship between // the integer variables and such literals which can be created by a call to @@ -260,11 +262,11 @@ class IntegerEncoder { // Same as CreateAssociatedLiteral() but safe to call if already created. Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit); - // Returns the IntegerLiteral that was associated with the given Boolean - // literal or an IntegerLiteral with a variable set to kNoIntegerVariable if - // the argument does not correspond to such literal. - IntegerLiteral GetIntegerLiteral(Literal lit) const { - if (lit.Index() >= reverse_encoding_.size()) return IntegerLiteral(); + // Returns the IntegerLiterals that were associated with the given Literal. + const InlinedIntegerLiteralVector& GetIntegerLiterals(Literal lit) const { + if (lit.Index() >= reverse_encoding_.size()) { + return empty_integer_literal_vector_; + } return reverse_encoding_[lit.Index()]; } @@ -291,10 +293,9 @@ class IntegerEncoder { // corresponding to the same variable). ITIVector> encoding_by_var_; - // Store for a given LiteralIndex its associated IntegerLiteral or an - // IntegerLiteral with kNoIntegerVariable as a variable if the LiteralIndex - // doesn't correspond to an IntegerLiteral. - ITIVector reverse_encoding_; + // Store for a given LiteralIndex the list of its associated IntegerLiterals. + const InlinedIntegerLiteralVector empty_integer_literal_vector_; + ITIVector reverse_encoding_; // Full domain encoding. The map contains the index in full_encoding_ of // the fully encoded variable. Each entry in full_encoding_ is sorted by @@ -744,12 +745,7 @@ inline std::function Equality(IntegerVariable v, int64 value) { inline std::function Equality(IntegerLiteral i, Literal l) { return [=](Model* model) { IntegerEncoder* encoder = model->GetOrCreate(); - - // Tricky: currently we cannot associate the same literal to two different - // IntegerLiteral! The second test verifies that l is not already - // associated. - if (encoder->LiteralIsAssociated(i) || - encoder->GetIntegerLiteral(l) != IntegerLiteral()) { + if (encoder->LiteralIsAssociated(i)) { const Literal current = encoder->GetOrCreateAssociatedLiteral(i); model->Add(Equality(current, l)); } else { diff --git a/src/sat/integer_expr.h b/src/sat/integer_expr.h index eedf328cf4..bb0370dab2 100644 --- a/src/sat/integer_expr.h +++ b/src/sat/integer_expr.h @@ -199,7 +199,6 @@ inline std::function WeightedSumLowerOrEqual( const std::vector& vars, const VectorInt& coefficients, int64 upper_bound) { // Special cases. - // TODO(user): Do the same for the reified case. CHECK_GE(vars.size(), 1) << "Should be encoded differently."; if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) && (coefficients[1] == 1 || coefficients[1] == -1)) { @@ -230,7 +229,7 @@ template inline std::function WeightedSumGreaterOrEqual( const std::vector& vars, const VectorInt& coefficients, int64 lower_bound) { - // We just negate everything and use an IntegerSumLE() constraints. + // We just negate everything and use an <= constraints. std::vector negated_coeffs(coefficients.begin(), coefficients.end()); for (IntegerValue& ref : negated_coeffs) ref = -ref; return WeightedSumLowerOrEqual(vars, negated_coeffs, -lower_bound); @@ -247,33 +246,61 @@ inline std::function FixedWeightedSum( }; } +// is_le => sum <= upper_bound +template +inline std::function ConditionalWeightedSumLowerOrEqual( + Literal is_le, const std::vector& vars, + const VectorInt& coefficients, int64 upper_bound) { + // Special cases. + CHECK_GE(vars.size(), 1) << "Should be encoded differently."; + if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) && + (coefficients[1] == 1 || coefficients[1] == -1)) { + return ConditionalSum2LowerOrEqual( + coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]), + coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]), upper_bound, + is_le); + } + if (vars.size() == 3 && (coefficients[0] == 1 || coefficients[0] == -1) && + (coefficients[1] == 1 || coefficients[1] == -1) && + (coefficients[2] == 1 || coefficients[2] == -1)) { + return ConditionalSum3LowerOrEqual( + coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]), + coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]), + coefficients[2] == 1 ? vars[2] : NegationOf(vars[2]), upper_bound, + is_le); + } + return [=](Model* model) { + IntegerSumLE* constraint = new IntegerSumLE( + is_le.Index(), vars, + std::vector(coefficients.begin(), coefficients.end()), + IntegerValue(upper_bound), model->GetOrCreate()); + constraint->RegisterWith(model->GetOrCreate()); + model->TakeOwnership(constraint); + }; +} + +// is_ge => sum >= lower_bound +template +inline std::function ConditionalWeightedSumGreaterOrEqual( + Literal is_ge, const std::vector& vars, + const VectorInt& coefficients, int64 lower_bound) { + // We just negate everything and use an <= constraint. + std::vector negated_coeffs(coefficients.begin(), coefficients.end()); + for (IntegerValue& ref : negated_coeffs) ref = -ref; + return ConditionalWeightedSumLowerOrEqual(is_ge, vars, negated_coeffs, + -lower_bound); +} + // Weighted sum <= constant reified. template inline std::function WeightedSumLowerOrEqualReif( Literal is_le, const std::vector& vars, const VectorInt& coefficients, int64 upper_bound) { return [=](Model* model) { - // is_le => lin <= upper_bound - { - IntegerSumLE* constraint = new IntegerSumLE( - is_le.Index(), vars, - std::vector(coefficients.begin(), coefficients.end()), - IntegerValue(upper_bound), model->GetOrCreate()); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); - } - - // not(is_le) => lin > upper_bound, i.e -lin <= -upper_bound - 1 - { - std::vector negated_coeffs(coefficients.begin(), - coefficients.end()); - for (IntegerValue& ref : negated_coeffs) ref = -ref; - IntegerSumLE* constraint = new IntegerSumLE( - is_le.NegatedIndex(), vars, negated_coeffs, - IntegerValue(-upper_bound - 1), model->GetOrCreate()); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); - } + model->Add(ConditionalWeightedSumLowerOrEqual(is_le, vars, coefficients, + upper_bound)); + model->Add(ConditionalWeightedSumGreaterOrEqual( + is_le.Negated(), vars, coefficients, upper_bound + 1)); }; } @@ -282,8 +309,12 @@ template inline std::function WeightedSumGreaterOrEqualReif( Literal is_ge, const std::vector& vars, const VectorInt& coefficients, int64 lower_bound) { - return WeightedSumLowerOrEqualReif(is_ge.Negated(), vars, coefficients, - lower_bound - 1); + return [=](Model* model) { + model->Add(ConditionalWeightedSumGreaterOrEqual(is_ge, vars, coefficients, + lower_bound)); + model->Add(ConditionalWeightedSumLowerOrEqual( + is_ge.Negated(), vars, coefficients, lower_bound - 1)); + }; } // Weighted sum == constant reified. @@ -308,14 +339,13 @@ inline std::function WeightedSumNotEqual( const std::vector& vars, const VectorInt& coefficients, int64 value) { return [=](Model* model) { - // We creates two extra Boolean variables in this case. + // Exactly one of these alternative must be true. const Literal is_lt = Literal(model->Add(NewBooleanVariable()), true); - const Literal is_gt = Literal(model->Add(NewBooleanVariable()), true); - model->Add(ClauseConstraint({is_lt, is_gt})); - model->Add( - WeightedSumLowerOrEqualReif(is_lt, vars, coefficients, value - 1)); - model->Add( - WeightedSumGreaterOrEqualReif(is_gt, vars, coefficients, value + 1)); + const Literal is_gt = is_lt.Negated(); + model->Add(ConditionalWeightedSumLowerOrEqual(is_lt, vars, coefficients, + value - 1)); + model->Add(ConditionalWeightedSumGreaterOrEqual(is_gt, vars, coefficients, + value + 1)); }; } diff --git a/src/sat/precedences.cc b/src/sat/precedences.cc index c40dff56c8..b4754a0f7a 100644 --- a/src/sat/precedences.cc +++ b/src/sat/precedences.cc @@ -211,22 +211,6 @@ void PrecedencesPropagator::MarkIntegerVariableAsOptional(IntegerVariable i, void PrecedencesPropagator::AddArc(IntegerVariable tail, IntegerVariable head, IntegerValue offset, IntegerVariable offset_var, LiteralIndex l) { - if (head == tail) { - // A self-arc is either plain SAT or plan UNSAT or it forces something on - // the given offset_var or l. In any case it could be presolved in something - // more efficent. - LOG(WARNING) << "Self arc! This could be presolved. " - << "var:" << tail << " offset:" << offset - << " offset_var:" << offset_var << " conditioned_by:" << l; - if (offset <= 0 && offset_var == kNoIntegerVariable && - l == kNoLiteralIndex) { - return; // no-op. - } - } - AdjustSizeFor(tail); - AdjustSizeFor(head); - if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var); - // Handle level zero stuff. DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); if (l != kNoLiteralIndex) { @@ -238,6 +222,24 @@ void PrecedencesPropagator::AddArc(IntegerVariable tail, IntegerVariable head, } } + if (head == tail) { + // A self-arc is either plain SAT or plan UNSAT or it forces something on + // the given offset_var or l. In any case it could be presolved in something + // more efficent. + LOG(WARNING) << "Self arc! This could be presolved. " + << "var:" << tail << " offset:" << offset + << " offset_var:" << offset_var << " conditioned_by:" << l; + if (offset_var == kNoIntegerVariable) { + // Always false => l is false, otherwise this is a no op. + if (offset > 0) trail_->EnqueueWithUnitReason(Literal(l).Negated()); + return; + } + } + + AdjustSizeFor(tail); + AdjustSizeFor(head); + if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var); + if (l != kNoLiteralIndex && l.value() >= potential_arcs_.size()) { potential_arcs_.resize(l.value() + 1); } diff --git a/src/sat/precedences.h b/src/sat/precedences.h index 49df2f5597..7bc9f15b40 100644 --- a/src/sat/precedences.h +++ b/src/sat/precedences.h @@ -87,10 +87,11 @@ class PrecedencesPropagator : public Propagator { // when I wrote this, I just had a couple of problems to test this on. void AddPrecedenceWithVariableOffset(IntegerVariable i1, IntegerVariable i2, IntegerVariable offset_var); - void AddPrecedenceWithVariableAndFixedOffset(IntegerVariable i1, - IntegerVariable i2, - IntegerValue offset, - IntegerVariable offset_var); + + // Generic function that cover all of the above case and more. + void AddPrecedenceWithAllOptions(IntegerVariable i1, IntegerVariable i2, + IntegerValue offset, + IntegerVariable offset_var, LiteralIndex l); // An optional integer variable has a special behavior: // - If the bounds on i cross each other, then is_present must be false. @@ -312,10 +313,10 @@ inline void PrecedencesPropagator::AddPrecedenceWithVariableOffset( AddArc(i1, i2, /*offset=*/IntegerValue(0), offset_var, /*l=*/kNoLiteralIndex); } -inline void PrecedencesPropagator::AddPrecedenceWithVariableAndFixedOffset( +inline void PrecedencesPropagator::AddPrecedenceWithAllOptions( IntegerVariable i1, IntegerVariable i2, IntegerValue offset, - IntegerVariable offset_var) { - AddArc(i1, i2, offset, offset_var, /*l=*/kNoLiteralIndex); + IntegerVariable offset_var, LiteralIndex r) { + AddArc(i1, i2, offset, offset_var, r); } // ============================================================================= @@ -347,15 +348,36 @@ inline std::function Sum2LowerOrEqual(IntegerVariable a, return LowerOrEqualWithOffset(a, NegationOf(b), -ub); } +// l => (a + b <= ub). +inline std::function ConditionalSum2LowerOrEqual( + IntegerVariable a, IntegerVariable b, int64 ub, Literal l) { + return [=](Model* model) { + PrecedencesPropagator* p = model->GetOrCreate(); + p->AddPrecedenceWithAllOptions(a, NegationOf(b), IntegerValue(-ub), + kNoIntegerVariable, l.Index()); + }; +} + // a + b + c <= ub. inline std::function Sum3LowerOrEqual(IntegerVariable a, IntegerVariable b, IntegerVariable c, int64 ub) { return [=](Model* model) { - return model->GetOrCreate() - ->AddPrecedenceWithVariableAndFixedOffset(a, NegationOf(c), - IntegerValue(-ub), b); + PrecedencesPropagator* p = model->GetOrCreate(); + p->AddPrecedenceWithAllOptions(a, NegationOf(c), IntegerValue(-ub), b, + kNoLiteralIndex); + }; +} + +// l => (a + b + c <= ub). +inline std::function ConditionalSum3LowerOrEqual( + IntegerVariable a, IntegerVariable b, IntegerVariable c, int64 ub, + Literal l) { + return [=](Model* model) { + PrecedencesPropagator* p = model->GetOrCreate(); + p->AddPrecedenceWithAllOptions(a, NegationOf(c), IntegerValue(-ub), b, + l.Index()); }; } @@ -399,7 +421,7 @@ inline std::function ReifiedLowerOrEqualWithOffset( }; } -// is_eq <=> (a + offset == b). +// is_eq <=> (a == b). inline std::function ReifiedEquality(IntegerVariable a, IntegerVariable b, Literal is_eq) { @@ -421,12 +443,11 @@ inline std::function ReifiedEquality(IntegerVariable a, inline std::function NotEqual(IntegerVariable a, IntegerVariable b) { return [=](Model* model) { - // We model this by is_le and is_ge cannot be both true. - const Literal is_le = Literal(model->Add(NewBooleanVariable()), true); - const Literal is_ge = Literal(model->Add(NewBooleanVariable()), true); - model->Add(Implication(is_le, is_ge.Negated())); - model->Add(ReifiedLowerOrEqualWithOffset(a, b, 0, is_le)); - model->Add(ReifiedLowerOrEqualWithOffset(b, a, 0, is_ge)); + // We have two options (is_gt or is_lt) and one must be true. + const Literal is_lt = Literal(model->Add(NewBooleanVariable()), true); + const Literal is_gt = is_lt.Negated(); + model->Add(ConditionalLowerOrEqualWithOffset(a, b, 1, is_lt)); + model->Add(ConditionalLowerOrEqualWithOffset(b, a, 1, is_gt)); }; } diff --git a/src/util/sorted_interval_list.cc b/src/util/sorted_interval_list.cc index e2bd25a233..1d4c94b331 100644 --- a/src/util/sorted_interval_list.cc +++ b/src/util/sorted_interval_list.cc @@ -117,6 +117,45 @@ SortedDisjointIntervalList::Iterator SortedDisjointIntervalList::InsertInterval( return it; } +SortedDisjointIntervalList::Iterator SortedDisjointIntervalList::GrowRightByOne( + int64 value, int64* newly_covered) { + auto it = intervals_.upper_bound({value, kint64max}); + auto it_prev = it; + + // No interval containing or adjacent to "value" on the left (i.e. below). + if (it == begin() || ((--it_prev)->end < value - 1 && value != kint64min)) { + *newly_covered = value; + if (it == end() || it->start != value + 1) { + // No interval adjacent to "value" on the right: insert a singleton. + return intervals_.insert(it, {value, value}); + } else { + // There is an interval adjacent to "value" on the right. Extend it by + // one. Note that we already know that there won't be a merge with another + // interval on the left, since there were no interval adjacent to "value" + // on the left. + DCHECK_EQ(it->start, value + 1); + const_cast(&(*it))->start = value; + return it; + } + } + + // At this point, "it_prev" points to an interval containing or adjacent to + // "value" on the left: grow it by one, and if it now touches the next + // interval, merge with it. + CHECK_NE(kint64max, it_prev->end) << "Cannot grow right by one: the interval " + "that would grow already ends at " + "kint64max"; + *newly_covered = it_prev->end + 1; + if (it != end() && it_prev->end + 2 == it->start) { + // We need to merge it_prev with 'it'. + const_cast(&(*it_prev))->end = it->end; + intervals_.erase(it); + } else { + const_cast(&(*it_prev))->end = it_prev->end + 1; + } + return it_prev; +} + template void SortedDisjointIntervalList::InsertAll(const std::vector& starts, const std::vector& ends) { diff --git a/src/util/sorted_interval_list.h b/src/util/sorted_interval_list.h index f2a591d5c4..d3f87c87b0 100644 --- a/src/util/sorted_interval_list.h +++ b/src/util/sorted_interval_list.h @@ -72,6 +72,15 @@ class SortedDisjointIntervalList { // If start > end, it does LOG(DFATAL) and returns end() (no interval added). Iterator InsertInterval(int64 start, int64 end); + // If value is in an interval, increase its end by one, otherwise insert the + // interval [value, value]. In both cases, this returns an iterator to the + // new/modified interval (possibly merged with others) and fills newly_covered + // with the new value that was just added in the union of all the intervals. + // + // If this causes an interval ending at kint64max to grow, it will die with a + // CHECK fail. + Iterator GrowRightByOne(int64 value, int64* newly_covered); + // Adds all intervals [starts[i]..ends[i]]. Same behavior as InsertInterval() // upon invalid intervals. There's a version with int64 and int32. void InsertIntervals(const std::vector& starts, const std::vector& ends);