diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 478f39bb0a..cf6740eed8 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1147,6 +1147,8 @@ void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { const IntegerValue denom(m->Get(Value(vars[1]))); if (denom == 1) { m->Add(Equality(vars[0], div)); + } else if (denom == -1) { + m->Add(Equality(NegationOf(vars[0]), div)); } else { m->Add(FixedDivisionConstraint(vars[0], denom, div)); } diff --git a/ortools/sat/cumulative_energy.cc b/ortools/sat/cumulative_energy.cc index 3f1580aa9f..48dc9815e8 100644 --- a/ortools/sat/cumulative_energy.cc +++ b/ortools/sat/cumulative_energy.cc @@ -57,14 +57,10 @@ void AddCumulativeOverloadChecker(const std::vector& demands, energies.emplace_back(demand.constant * size.constant); } else if (demand.var == kNoIntegerVariable) { CHECK_GE(demand.constant, 0); - energies.push_back(size); - energies.back().coeff *= demand.constant; - energies.back().constant *= demand.constant; + energies.push_back(size.MultipliedBy(demand.constant)); } else if (size.var == kNoIntegerVariable) { CHECK_GE(size.constant, 0); - energies.push_back(demand); - energies.back().coeff *= size.constant; - energies.back().constant *= size.constant; + energies.push_back(demand.MultipliedBy(size.constant)); } else { // The case where both demand and size are variable should be rare. // diff --git a/ortools/sat/implied_bounds.cc b/ortools/sat/implied_bounds.cc index e8c2cd3150..6bd95dc94e 100644 --- a/ortools/sat/implied_bounds.cc +++ b/ortools/sat/implied_bounds.cc @@ -280,30 +280,29 @@ bool TryToReconcileEncodings( Literal lit1 = size2_enc[1].literal; IntegerValue value1 = size2_enc[1].value * size2_affine.coeff + size2_affine.constant; - for (const auto& [unused_value, candidate_literal] : affine_var_encoding) { + for (const auto& [unused, candidate_literal] : affine_var_encoding) { if (candidate_literal == lit1) { std::swap(lit0, lit1); std::swap(value0, value1); } if (candidate_literal != lit0) continue; - builder->Clear(); - // Compute the minimum energy. IntegerValue min_energy = kMaxIntegerValue; for (const auto& [value, literal] : affine_var_encoding) { - const IntegerValue energy = - literal == lit0 ? value0 * (affine.coeff * value + affine.constant) - : value1 * (affine.coeff * value + affine.constant); + const IntegerValue energy = literal == lit0 + ? value0 * affine.ValueAt(value) + : value1 * affine.ValueAt(value); min_energy = std::min(energy, min_energy); } - builder->AddConstant(min_energy); // Build the energy expression. + builder->Clear(); + builder->AddConstant(min_energy); for (const auto& [value, literal] : affine_var_encoding) { - const IntegerValue energy = - literal == lit0 ? value0 * (affine.coeff * value + affine.constant) - : value1 * (affine.coeff * value + affine.constant); + const IntegerValue energy = literal == lit0 + ? value0 * affine.ValueAt(value) + : value1 * affine.ValueAt(value); if (energy > min_energy) { if (!builder->AddLiteralTerm(literal, energy - min_energy)) { return false; @@ -325,14 +324,14 @@ bool DetectLinearEncodingOfProducts(const AffineExpression& left, IntegerTrail* integer_trail = model->GetOrCreate(); ImpliedBounds* implied_bounds = model->GetOrCreate(); - if (left.IsFixed(integer_trail)) { - const IntegerValue value = left.Value(integer_trail); + if (integer_trail->IsFixed(left)) { + const IntegerValue value = integer_trail->LowerBound(left); builder->AddTerm(right, value); return true; } - if (right.IsFixed(integer_trail)) { - const IntegerValue value = right.Value(integer_trail); + if (integer_trail->IsFixed(right)) { + const IntegerValue value = integer_trail->LowerBound(right); builder->AddTerm(left, value); return true; } @@ -391,19 +390,15 @@ bool DetectLinearEncodingOfProducts(const AffineExpression& left, // Compute the min energy. IntegerValue min_energy = kMaxIntegerValue; for (int i = 0; i < left_encoding.size(); ++i) { - const IntegerValue left_value = left_encoding[i].value; - const IntegerValue right_value = right_encoding[i].value; - const IntegerValue energy = (left.coeff * left_value + left.constant) * - (right.coeff * right_value + right.constant); + const IntegerValue energy = left.ValueAt(left_encoding[i].value) * + right.ValueAt(right_encoding[i].value); min_energy = std::min(min_energy, energy); } // Build the linear formulation of the energy. for (int i = 0; i < left_encoding.size(); ++i) { - const IntegerValue left_value = left_encoding[i].value; - const IntegerValue right_value = right_encoding[i].value; - const IntegerValue energy = (left.coeff * left_value + left.constant) * - (right.coeff * right_value + right.constant); + const IntegerValue energy = left.ValueAt(left_encoding[i].value) * + right.ValueAt(right_encoding[i].value); if (energy == min_energy) continue; DCHECK_GT(energy, min_energy); const Literal lit = left_encoding[i].literal; diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 0789e23168..594fee8387 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -36,40 +36,6 @@ std::vector NegationOf( return result; } -IntegerValue AffineExpression::Min(IntegerTrail* integer_trail) const { - IntegerValue result = constant; - if (var != kNoIntegerVariable) { - if (coeff > 0) { - result += coeff * integer_trail->LowerBound(var); - } else { - result += coeff * integer_trail->UpperBound(var); - } - } - return result; -} - -IntegerValue AffineExpression::Max(IntegerTrail* integer_trail) const { - IntegerValue result = constant; - if (var != kNoIntegerVariable) { - if (coeff > 0) { - result += coeff * integer_trail->UpperBound(var); - } else { - result += coeff * integer_trail->LowerBound(var); - } - } - return result; -} - -bool AffineExpression::IsFixed(IntegerTrail* integer_trail) const { - if (var == kNoIntegerVariable || coeff == 0) return true; - return integer_trail->IsFixed(var); -} - -IntegerValue AffineExpression::Value(IntegerTrail* integer_trail) const { - DCHECK(IsFixed(integer_trail)); - return Max(integer_trail); -} - std::string ValueLiteralPair::DebugString() const { return absl::StrCat("(literal = ", literal.DebugString(), ", value = ", value.value(), ")"); @@ -1040,6 +1006,25 @@ std::string IntegerTrail::DebugString() { return result; } +bool IntegerTrail::UnsafeEnqueue( + IntegerLiteral i_lit, absl::Span literal_reason, + absl::Span integer_reason) { + if (i_lit.IsTrueLiteral()) return true; + + std::vector cleaned_reason; + for (const IntegerLiteral lit : integer_reason) { + DCHECK(!lit.IsFalseLiteral()); + if (lit.IsTrueLiteral()) continue; + cleaned_reason.push_back(lit); + } + + if (i_lit.IsFalseLiteral()) { + return ReportConflict(literal_reason, cleaned_reason); + } else { + return Enqueue(i_lit, literal_reason, cleaned_reason); + } +} + bool IntegerTrail::Enqueue(IntegerLiteral i_lit, absl::Span literal_reason, absl::Span integer_reason) { diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 0cf184a5e2..e51d2db8e2 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -174,6 +174,11 @@ struct IntegerLiteral { static IntegerLiteral GreaterOrEqual(IntegerVariable i, IntegerValue bound); static IntegerLiteral LowerOrEqual(IntegerVariable i, IntegerValue bound); + // These two static integer literals represent an always true and an always + // false condition. + static IntegerLiteral TrueLiteral(); + static IntegerLiteral FalseLiteral(); + // Clients should prefer the static construction methods above. IntegerLiteral() : var(kNoIntegerVariable), bound(0) {} IntegerLiteral(IntegerVariable v, IntegerValue b) : var(v), bound(b) { @@ -182,6 +187,8 @@ struct IntegerLiteral { } bool IsValid() const { return var != kNoIntegerVariable; } + bool IsTrueLiteral() const { return var == kNoIntegerVariable && bound <= 0; } + bool IsFalseLiteral() const { return var == kNoIntegerVariable && bound > 0; } // The negation of x >= bound is x <= bound - 1. IntegerLiteral Negated() const; @@ -210,7 +217,6 @@ inline std::ostream& operator<<(std::ostream& os, IntegerLiteral i_lit) { } using InlinedIntegerLiteralVector = absl::InlinedVector; -class IntegerTrail; // Represents [coeff * variable + constant] or just a [constant]. // @@ -232,23 +238,29 @@ struct AffineExpression { // Returns the integer literal corresponding to expression >= value or // expression <= value. // - // These should not be called on constant expression (CHECKED). + // On constant expressions, they will return IntegerLiteral::TrueLiteral() + // or IntegerLiteral::FalseLiteral(). IntegerLiteral GreaterOrEqual(IntegerValue bound) const; IntegerLiteral LowerOrEqual(IntegerValue bound) const; AffineExpression Negated() const { + if (var == kNoIntegerVariable) return AffineExpression(-constant); return AffineExpression(NegationOf(var), coeff, -constant); } + AffineExpression MultipliedBy(IntegerValue multiplier) const { + // Note that this also works if multiplier is negative. + return AffineExpression(var, coeff * multiplier, constant * multiplier); + } + bool operator==(AffineExpression o) const { return var == o.var && coeff == o.coeff && constant == o.constant; } - // Getters on the bounds of the affine expression. - IntegerValue Min(IntegerTrail* integer_trail) const; - IntegerValue Max(IntegerTrail* integer_trail) const; - IntegerValue Value(IntegerTrail* integer_trail) const; - bool IsFixed(IntegerTrail* integer_trail) const; + // Returns the value of this affine expression given its variable value. + IntegerValue ValueAt(IntegerValue var_value) const { + return coeff * var_value + constant; + } // Returns the affine expression value under a given LP solution. double LpValue( @@ -268,6 +280,9 @@ struct AffineExpression { } // The coefficient MUST be positive. Use NegationOf(var) if needed. + // + // TODO(user): Make this private to enforce the invariant that coeff cannot be + // negative. IntegerVariable var = kNoIntegerVariable; // kNoIntegerVariable for constant. IntegerValue coeff = IntegerValue(0); // Zero for constant. IntegerValue constant = IntegerValue(0); @@ -712,6 +727,12 @@ class IntegerTrail : public SatPropagator { IntegerLiteral LowerBoundAsLiteral(IntegerVariable i) const; IntegerLiteral UpperBoundAsLiteral(IntegerVariable i) const; + // Returns the integer literal that represent the current lower/upper bound of + // the given integer variable. In case the expression is constant, it returns + // IntegerLiteral::TrueLiteral(). + IntegerLiteral LowerBoundAsLiteral(AffineExpression expr) const; + IntegerLiteral UpperBoundAsLiteral(AffineExpression expr) const; + // Returns the current value (if known) of an IntegerLiteral. bool IntegerLiteralIsTrue(IntegerLiteral l) const; bool IntegerLiteralIsFalse(IntegerLiteral l) const; @@ -788,6 +809,18 @@ class IntegerTrail : public SatPropagator { IntegerLiteral i_lit, absl::Span literal_reason, absl::Span integer_reason); + // Enqueue new information about a variable bound. It has the same behavior + // as the Enqueue() method, except that it accepts true and false integer + // literals, both for i_lit, and for the integer reason. + // This method will do nothing if i_lit is a true literal. It will report a + // conflict if i_lit is a false literal, and enqueue i_lit normally otherwise. + // Furthemore, it will check that the integer reason does not contain any + // false literals, and will remove true literals before calling + // ReportConflict() or Enqueue(). + ABSL_MUST_USE_RESULT bool UnsafeEnqueue( + IntegerLiteral i_lit, absl::Span literal_reason, + absl::Span integer_reason); + // Pushes the given integer literal assuming that the Boolean literal is true. // This can do a few things: // - If lit it true, add it to the reason and push the integer bound. @@ -1336,6 +1369,14 @@ inline IntegerLiteral IntegerLiteral::LowerOrEqual(IntegerVariable i, NegationOf(i), bound < kMinIntegerValue ? kMaxIntegerValue + 1 : -bound); } +inline IntegerLiteral IntegerLiteral::TrueLiteral() { + return IntegerLiteral(kNoIntegerVariable, IntegerValue(-1)); +} + +inline IntegerLiteral IntegerLiteral::FalseLiteral() { + return IntegerLiteral(kNoIntegerVariable, IntegerValue(1)); +} + inline IntegerLiteral IntegerLiteral::Negated() const { // Note that bound >= kMinIntegerValue, so -bound + 1 will have the correct // capped value. @@ -1347,7 +1388,10 @@ inline IntegerLiteral IntegerLiteral::Negated() const { // var * coeff + constant >= bound. inline IntegerLiteral AffineExpression::GreaterOrEqual( IntegerValue bound) const { - DCHECK_NE(var, kNoIntegerVariable); + if (var == kNoIntegerVariable) { + return constant >= bound ? IntegerLiteral::TrueLiteral() + : IntegerLiteral::FalseLiteral(); + } DCHECK_GT(coeff, 0); return IntegerLiteral::GreaterOrEqual(var, CeilRatio(bound - constant, coeff)); @@ -1355,7 +1399,10 @@ inline IntegerLiteral AffineExpression::GreaterOrEqual( // var * coeff + constant <= bound. inline IntegerLiteral AffineExpression::LowerOrEqual(IntegerValue bound) const { - DCHECK_NE(var, kNoIntegerVariable); + if (var == kNoIntegerVariable) { + return constant <= bound ? IntegerLiteral::TrueLiteral() + : IntegerLiteral::FalseLiteral(); + } DCHECK_GT(coeff, 0); return IntegerLiteral::LowerOrEqual(var, FloorRatio(bound - constant, coeff)); } @@ -1412,6 +1459,18 @@ inline IntegerLiteral IntegerTrail::UpperBoundAsLiteral( return IntegerLiteral::LowerOrEqual(i, UpperBound(i)); } +inline IntegerLiteral IntegerTrail::LowerBoundAsLiteral( + AffineExpression expr) const { + if (expr.var == kNoIntegerVariable) return IntegerLiteral::TrueLiteral(); + return IntegerLiteral::GreaterOrEqual(expr.var, LowerBound(expr.var)); +} + +inline IntegerLiteral IntegerTrail::UpperBoundAsLiteral( + AffineExpression expr) const { + if (expr.var == kNoIntegerVariable) return IntegerLiteral::TrueLiteral(); + return IntegerLiteral::LowerOrEqual(expr.var, UpperBound(expr.var)); +} + inline bool IntegerTrail::IntegerLiteralIsTrue(IntegerLiteral l) const { return l.bound <= LowerBound(l.var); } diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 04327a0d2d..76eb8d85de 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -871,9 +871,9 @@ void PositiveDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); } -FixedDivisionPropagator::FixedDivisionPropagator(IntegerVariable a, +FixedDivisionPropagator::FixedDivisionPropagator(AffineExpression a, IntegerValue b, - IntegerVariable c, + AffineExpression c, IntegerTrail* integer_trail) : a_(a), b_(b), c_(c), integer_trail_(integer_trail) {} @@ -887,8 +887,9 @@ bool FixedDivisionPropagator::Propagate() { if (max_a / b_ < max_c) { max_c = max_a / b_; - if (!integer_trail_->Enqueue(IntegerLiteral::LowerOrEqual(c_, max_c), {}, - {integer_trail_->UpperBoundAsLiteral(a_)})) { + if (!integer_trail_->UnsafeEnqueue( + c_.LowerOrEqual(max_c), {}, + {integer_trail_->UpperBoundAsLiteral(a_)})) { return false; } } else if (max_a / b_ > max_c) { @@ -896,17 +897,18 @@ bool FixedDivisionPropagator::Propagate() { max_c >= 0 ? max_c * b_ + b_ - 1 : IntegerValue(CapProd(max_c.value(), b_.value())); CHECK_LT(new_max_a, max_a); - if (!integer_trail_->Enqueue(IntegerLiteral::LowerOrEqual(a_, new_max_a), - {}, - {integer_trail_->UpperBoundAsLiteral(c_)})) { + if (!integer_trail_->UnsafeEnqueue( + a_.LowerOrEqual(new_max_a), {}, + {integer_trail_->UpperBoundAsLiteral(c_)})) { return false; } } if (min_a / b_ > min_c) { min_c = min_a / b_; - if (!integer_trail_->Enqueue(IntegerLiteral::GreaterOrEqual(c_, min_c), {}, - {integer_trail_->LowerBoundAsLiteral(a_)})) { + if (!integer_trail_->UnsafeEnqueue( + c_.GreaterOrEqual(min_c), {}, + {integer_trail_->LowerBoundAsLiteral(a_)})) { return false; } } else if (min_a / b_ < min_c) { @@ -914,9 +916,9 @@ bool FixedDivisionPropagator::Propagate() { min_c > 0 ? IntegerValue(CapProd(min_c.value(), b_.value())) : min_c * b_ - b_ + 1; CHECK_GT(new_min_a, min_a); - if (!integer_trail_->Enqueue(IntegerLiteral::GreaterOrEqual(a_, new_min_a), - {}, - {integer_trail_->LowerBoundAsLiteral(c_)})) { + if (!integer_trail_->UnsafeEnqueue( + a_.GreaterOrEqual(new_min_a), {}, + {integer_trail_->LowerBoundAsLiteral(c_)})) { return false; } } @@ -926,8 +928,8 @@ bool FixedDivisionPropagator::Propagate() { void FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - watcher->WatchIntegerVariable(a_, id); - watcher->WatchIntegerVariable(c_, id); + watcher->WatchAffineExpression(a_, id); + watcher->WatchAffineExpression(c_, id); } std::function IsOneOf(IntegerVariable var, diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index d9b8a417ea..381a80f6ff 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -256,16 +256,16 @@ class PositiveDivisionPropagator : public PropagatorInterface { // cases, and we only propagates the bounds. cst_b must be > 0. class FixedDivisionPropagator : public PropagatorInterface { public: - FixedDivisionPropagator(IntegerVariable a, IntegerValue b, IntegerVariable c, - IntegerTrail* integer_trail); + FixedDivisionPropagator(AffineExpression a, IntegerValue b, + AffineExpression c, IntegerTrail* integer_trail); bool Propagate() final; void RegisterWith(GenericLiteralWatcher* watcher); private: - const IntegerVariable a_; + const AffineExpression a_; const IntegerValue b_; - const IntegerVariable c_; + const AffineExpression c_; IntegerTrail* integer_trail_; DISALLOW_COPY_AND_ASSIGN(FixedDivisionPropagator); @@ -821,15 +821,14 @@ inline std::function DivisionConstraint(IntegerVariable num, } // Adds the constraint: a / b = c where b is a constant. -inline std::function FixedDivisionConstraint(IntegerVariable a, +inline std::function FixedDivisionConstraint(AffineExpression a, IntegerValue b, - IntegerVariable c) { + AffineExpression c) { return [=](Model* model) { IntegerTrail* integer_trail = model->GetOrCreate(); FixedDivisionPropagator* constraint = - b > 0 - ? new FixedDivisionPropagator(a, b, c, integer_trail) - : new FixedDivisionPropagator(NegationOf(a), -b, c, integer_trail); + b > 0 ? new FixedDivisionPropagator(a, b, c, integer_trail) + : new FixedDivisionPropagator(a.Negated(), -b, c, integer_trail); constraint->RegisterWith(model->GetOrCreate()); model->TakeOwnership(constraint); }; diff --git a/ortools/sat/linear_constraint.cc b/ortools/sat/linear_constraint.cc index d274883714..a8200e00da 100644 --- a/ortools/sat/linear_constraint.cc +++ b/ortools/sat/linear_constraint.cc @@ -69,13 +69,13 @@ void LinearConstraintBuilder::AddLinearExpression(const LinearExpression& expr, void LinearConstraintBuilder::AddQuadraticLowerBound( AffineExpression left, AffineExpression right, IntegerTrail* integer_trail) { - if (left.IsFixed(integer_trail)) { - AddTerm(right, left.Min(integer_trail)); - } else if (right.IsFixed(integer_trail)) { - AddTerm(left, right.Min(integer_trail)); + if (integer_trail->IsFixed(left)) { + AddTerm(right, integer_trail->LowerBound(left)); + } else if (integer_trail->IsFixed(right)) { + AddTerm(left, integer_trail->LowerBound(right)); } else { - const IntegerValue left_min = left.Min(integer_trail); - const IntegerValue right_min = right.Min(integer_trail); + const IntegerValue left_min = integer_trail->LowerBound(left); + const IntegerValue right_min = integer_trail->LowerBound(right); AddTerm(left, right_min); AddTerm(right, left_min); // Substract the energy counted twice. diff --git a/ortools/sat/scheduling_cuts.cc b/ortools/sat/scheduling_cuts.cc index 4fe2367a6f..68fe5ad8ba 100644 --- a/ortools/sat/scheduling_cuts.cc +++ b/ortools/sat/scheduling_cuts.cc @@ -361,11 +361,11 @@ CutGenerator CreateCumulativeEnergyCutGenerator( IntegerTrail* integer_trail = model->GetOrCreate(); for (const AffineExpression& demand_expr : demands) { - if (!demand_expr.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(demand_expr)) { result.vars.push_back(demand_expr.var); } } - if (!capacity.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(capacity)) { result.vars.push_back(capacity.var); } AddIntegerVariableFromIntervals(helper, model, &result.vars); @@ -425,11 +425,11 @@ CutGenerator CreateCumulativeTimeTableCutGenerator( IntegerTrail* integer_trail = model->GetOrCreate(); for (const AffineExpression& demand_expr : demands) { - if (!demand_expr.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(demand_expr)) { result.vars.push_back(demand_expr.var); } } - if (!capacity.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(capacity)) { result.vars.push_back(capacity.var); } AddIntegerVariableFromIntervals(helper, model, &result.vars); @@ -606,11 +606,11 @@ CutGenerator CreateCumulativePrecedenceCutGenerator( IntegerTrail* integer_trail = model->GetOrCreate(); for (const AffineExpression& demand_expr : demands) { - if (!demand_expr.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(demand_expr)) { result.vars.push_back(demand_expr.var); } } - if (!capacity.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(capacity)) { result.vars.push_back(capacity.var); } AddIntegerVariableFromIntervals(helper, model, &result.vars); @@ -925,11 +925,11 @@ CutGenerator CreateCumulativeCompletionTimeCutGenerator( IntegerTrail* integer_trail = model->GetOrCreate(); for (const AffineExpression& demand_expr : demands) { - if (!demand_expr.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(demand_expr)) { result.vars.push_back(demand_expr.var); } } - if (!capacity.IsFixed(integer_trail)) { + if (!integer_trail->IsFixed(capacity)) { result.vars.push_back(capacity.var); } AddIntegerVariableFromIntervals(helper, model, &result.vars);