diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 1a950abbfe..35036f24d7 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -752,7 +752,8 @@ void LoadNoOverlap2dConstraint(const ConstraintProto& ct, Model* m) { mapping->Intervals(ct.no_overlap_2d().x_intervals()); const std::vector y_intervals = mapping->Intervals(ct.no_overlap_2d().y_intervals()); - m->Add(StrictNonOverlappingRectangles(x_intervals, y_intervals)); + m->Add( + NonOverlappingRectangles(x_intervals, y_intervals, /*is_strict=*/true)); } void LoadCumulativeConstraint(const ConstraintProto& ct, Model* m) { diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 265c669852..3283db9143 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -140,31 +140,42 @@ bool PresolveContext::DomainContains(int ref, int64 value) const { return domains[ref].Contains(value); } -bool PresolveContext::IntersectDomainWith(int ref, const Domain& domain) { +ABSL_MUST_USE_RESULT bool PresolveContext::IntersectDomainWith( + int ref, const Domain& domain, bool* domain_modified) { DCHECK(!DomainIsEmpty(ref)); const int var = PositiveRef(ref); if (RefIsPositive(ref)) { - if (domains[var].IsIncludedIn(domain)) return false; + if (domains[var].IsIncludedIn(domain)) { + return true; + } domains[var] = domains[var].IntersectionWith(domain); } else { const Domain temp = domain.Negation(); - if (domains[var].IsIncludedIn(temp)) return false; + if (domains[var].IsIncludedIn(temp)) { + return true; + } domains[var] = domains[var].IntersectionWith(temp); } + if (domain_modified != nullptr) { + *domain_modified = true; + } modified_domains.Set(var); - if (domains[var].IsEmpty()) is_unsat = true; + if (domains[var].IsEmpty()) { + is_unsat = true; + return false; + } return true; } -void PresolveContext::SetLiteralToFalse(int lit) { +ABSL_MUST_USE_RESULT bool PresolveContext::SetLiteralToFalse(int lit) { const int var = PositiveRef(lit); const int64 value = RefIsPositive(lit) ? 0ll : 1ll; - IntersectDomainWith(var, Domain(value)); + return IntersectDomainWith(var, Domain(value)); } -void PresolveContext::SetLiteralToTrue(int lit) { +ABSL_MUST_USE_RESULT bool PresolveContext::SetLiteralToTrue(int lit) { return SetLiteralToFalse(NegatedRef(lit)); } @@ -204,7 +215,7 @@ void PresolveContext::UpdateNewConstraintsVariableUsage() { } bool PresolveContext::ConstraintVariableUsageIsConsistent() { - if (is_unsat) return true; + if (is_unsat) return false; if (constraint_to_vars.size() != working_model->constraints_size()) { LOG(INFO) << "Wrong constraint_to_vars size!"; return false; @@ -238,6 +249,7 @@ void PresolveContext::StoreAffineRelation(const ConstraintProto& ct, int ref_x, int64 offset) { int x = PositiveRef(ref_x); int y = PositiveRef(ref_y); + if (is_unsat) return; if (IsFixed(x) || IsFixed(y)) return; int64 c = RefIsPositive(ref_x) == RefIsPositive(ref_y) ? coeff : -coeff; @@ -338,7 +350,12 @@ AffineRelation::Relation PresolveContext::GetAffineRelation(int ref) { // Create the internal structure for any new variables in working_model. void PresolveContext::InitializeNewDomains() { for (int i = domains.size(); i < working_model->variables_size(); ++i) { - domains.push_back(ReadDomainFromProto(working_model->variables(i))); + Domain domain = ReadDomainFromProto(working_model->variables(i)); + if (domain.IsEmpty()) { + is_unsat = true; + return; + } + domains.push_back(domain); if (IsFixed(i)) ExploitFixedDomain(i); } modified_domains.Resize(domains.size()); @@ -398,6 +415,12 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) { namespace { +ABSL_MUST_USE_RESULT bool RemoveConstraint(ConstraintProto* ct, + PresolveContext* context) { + ct->Clear(); + return true; +} + // ============================================================================= // Presolve functions. // @@ -406,22 +429,23 @@ namespace { // // TODO(user): it migth be better to simply move all these functions to the // PresolveContext class. +// +// Invariant about UNSAT: All these functions should abort right away if +// context->IsUnsat() is true. And the only way to change the status to unsat is +// through ABSL_MUST_USE_RESULT function that should also abort right away the +// current code. This way we shouldn't keep doing computation on an inconsistent +// state. // ============================================================================= -ABSL_MUST_USE_RESULT bool RemoveConstraint(ConstraintProto* ct, - PresolveContext* context) { - ct->Clear(); - return true; -} - bool PresolveEnforcementLiteral(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (!HasEnforcementLiteral(*ct)) return false; int new_size = 0; const int old_size = ct->enforcement_literal().size(); for (const int literal : ct->enforcement_literal()) { - // Remove true literal. if (context->LiteralIsTrue(literal)) { + // We can remove a literal at true. context->UpdateRuleStats("true enforcement literal"); continue; } @@ -429,10 +453,12 @@ bool PresolveEnforcementLiteral(ConstraintProto* ct, PresolveContext* context) { if (context->LiteralIsFalse(literal)) { context->UpdateRuleStats("false enforcement literal"); return RemoveConstraint(ct, context); - } else if (context->VariableIsUniqueAndRemovable(literal)) { + } + + if (context->VariableIsUniqueAndRemovable(literal)) { // We can simply set it to false and ignore the constraint in this case. context->UpdateRuleStats("enforcement literal not used"); - context->SetLiteralToFalse(literal); + CHECK(context->SetLiteralToFalse(literal)); return RemoveConstraint(ct, context); } @@ -443,7 +469,9 @@ bool PresolveEnforcementLiteral(ConstraintProto* ct, PresolveContext* context) { } bool PresolveBoolXor(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (HasEnforcementLiteral(*ct)) return false; + int new_size = 0; bool changed = false; int num_true_literals = 0; @@ -489,6 +517,8 @@ bool PresolveBoolXor(ConstraintProto* ct, PresolveContext* context) { } bool PresolveBoolOr(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + // Move the enforcement literal inside the clause if any. Note that we do not // mark this as a change since the literal in the constraint are the same. if (HasEnforcementLiteral(*ct)) { @@ -517,7 +547,7 @@ bool PresolveBoolOr(ConstraintProto* ct, PresolveContext* context) { // objective var usage by 1). if (context->VariableIsUniqueAndRemovable(literal)) { context->UpdateRuleStats("bool_or: singleton"); - context->SetLiteralToTrue(literal); + if (!context->SetLiteralToTrue(literal)) return true; return RemoveConstraint(ct, context); } if (context->tmp_literal_set.contains(NegatedRef(literal))) { @@ -534,12 +564,11 @@ bool PresolveBoolOr(ConstraintProto* ct, PresolveContext* context) { if (context->tmp_literals.empty()) { context->UpdateRuleStats("bool_or: empty"); - context->is_unsat = true; - return true; + return context->NotifyThatModelIsUnsat(); } if (context->tmp_literals.size() == 1) { context->UpdateRuleStats("bool_or: only one literal"); - context->SetLiteralToTrue(context->tmp_literals[0]); + if (!context->SetLiteralToTrue(context->tmp_literals[0])) return true; return RemoveConstraint(ct, context); } if (context->tmp_literals.size() == 2) { @@ -573,21 +602,17 @@ ABSL_MUST_USE_RESULT bool MarkConstraintAsFalse(ConstraintProto* ct, PresolveBoolOr(ct, context); return true; } else { - context->is_unsat = true; - return RemoveConstraint(ct, context); + return context->NotifyThatModelIsUnsat(); } } bool PresolveBoolAnd(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + if (!HasEnforcementLiteral(*ct)) { context->UpdateRuleStats("bool_and: non-reified."); for (const int literal : ct->bool_and().literals()) { - if (context->LiteralIsFalse(literal)) { - context->is_unsat = true; - return true; - } else { - context->SetLiteralToTrue(literal); - } + if (!context->SetLiteralToTrue(literal)) return true; } return RemoveConstraint(ct, context); } @@ -605,7 +630,7 @@ bool PresolveBoolAnd(ConstraintProto* ct, PresolveContext* context) { } if (context->VariableIsUniqueAndRemovable(literal)) { changed = true; - context->SetLiteralToTrue(literal); + if (!context->SetLiteralToTrue(literal)) return true; continue; } context->tmp_literals.push_back(literal); @@ -627,6 +652,7 @@ bool PresolveBoolAnd(ConstraintProto* ct, PresolveContext* context) { } bool PresolveAtMostOne(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; CHECK(!HasEnforcementLiteral(*ct)); // Fix to false any duplicate literals. @@ -635,7 +661,7 @@ bool PresolveAtMostOne(ConstraintProto* ct, PresolveContext* context) { int previous = kint32max; for (const int literal : ct->at_most_one().literals()) { if (literal == previous) { - context->SetLiteralToFalse(literal); + if (!context->SetLiteralToFalse(literal)) return true; context->UpdateRuleStats("at_most_one: duplicate literals"); } previous = literal; @@ -647,7 +673,9 @@ bool PresolveAtMostOne(ConstraintProto* ct, PresolveContext* context) { if (context->LiteralIsTrue(literal)) { context->UpdateRuleStats("at_most_one: satisfied"); for (const int other : ct->at_most_one().literals()) { - if (other != literal) context->SetLiteralToFalse(other); + if (other != literal) { + if (!context->SetLiteralToFalse(other)) return true; + } } return RemoveConstraint(ct, context); } @@ -675,6 +703,7 @@ bool PresolveAtMostOne(ConstraintProto* ct, PresolveContext* context) { } bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (ct->int_max().vars().empty()) { return MarkConstraintAsFalse(ct, context); } @@ -727,8 +756,10 @@ bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) { infered_domain = infered_domain.UnionWith( context->DomainOf(ref).IntersectionWith({target_min, target_max})); } - domain_reduced |= context->IntersectDomainWith(target_ref, infered_domain); - if (context->is_unsat) return true; + if (!context->IntersectDomainWith(target_ref, infered_domain, + &domain_reduced)) { + return true; + } } // Pass 2, update the argument domains. Filter them eventually. @@ -737,8 +768,10 @@ bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) { target_max = context->MaxOf(target_ref); for (const int ref : ct->int_max().vars()) { if (!HasEnforcementLiteral(*ct)) { - domain_reduced |= - context->IntersectDomainWith(ref, Domain(kint64min, target_max)); + if (!context->IntersectDomainWith(ref, Domain(kint64min, target_max), + &domain_reduced)) { + return true; + } } if (context->MaxOf(ref) >= target_min) { ct->mutable_int_max()->set_vars(new_size++, ref); @@ -777,6 +810,8 @@ bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) { } bool PresolveIntMin(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + const auto copy = ct->int_min(); ct->mutable_int_max()->set_target(NegatedRef(copy.target())); for (const int ref : copy.vars()) { @@ -786,6 +821,7 @@ bool PresolveIntMin(ConstraintProto* ct, PresolveContext* context) { } bool PresolveIntProd(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (HasEnforcementLiteral(*ct)) return false; if (ct->int_prod().vars_size() == 2) { @@ -807,7 +843,11 @@ bool PresolveIntProd(ConstraintProto* ct, PresolveContext* context) { context->UpdateRuleStats("int_prod: linearize product by constant."); return RemoveConstraint(ct, context); } else if (context->MinOf(a) != 1) { - context->IntersectDomainWith(product, Domain(0, 0)); + bool domain_modified = false; + if (!context->IntersectDomainWith(product, Domain(0, 0), + &domain_modified)) { + return false; + } context->UpdateRuleStats("int_prod: fix variable to zero."); return RemoveConstraint(ct, context); } else { @@ -815,7 +855,9 @@ bool PresolveIntProd(ConstraintProto* ct, PresolveContext* context) { return RemoveConstraint(ct, context); } } else if (a == b && a == product) { // x = x * x, only true for {0, 1}. - context->IntersectDomainWith(product, Domain(0, 1)); + if (!context->IntersectDomainWith(product, Domain(0, 1))) { + return false; + } context->UpdateRuleStats("int_prod: fix variable to zero or one."); return RemoveConstraint(ct, context); } @@ -831,7 +873,9 @@ bool PresolveIntProd(ConstraintProto* ct, PresolveContext* context) { } // This is a bool constraint! - context->IntersectDomainWith(target_ref, Domain(0, 1)); + if (!context->IntersectDomainWith(target_ref, Domain(0, 1))) { + return false; + } context->UpdateRuleStats("int_prod: all Boolean."); { ConstraintProto* new_ct = context->working_model->add_constraints(); @@ -853,12 +897,15 @@ bool PresolveIntProd(ConstraintProto* ct, PresolveContext* context) { } bool PresolveIntDiv(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + // For now, we only presolve the case where the divisor is constant. const int target = ct->int_div().target(); const int ref_x = ct->int_div().vars(0); const int ref_div = ct->int_div().vars(1); if (!RefIsPositive(target) || !RefIsPositive(ref_x) || - !RefIsPositive(ref_div) || !context->IsFixed(ref_div)) { + !RefIsPositive(ref_div) || context->DomainIsEmpty(ref_div) || + !context->IsFixed(ref_div)) { return false; } @@ -866,10 +913,17 @@ bool PresolveIntDiv(ConstraintProto* ct, PresolveContext* context) { if (divisor == 1) { context->UpdateRuleStats("TODO int_div: rewrite to equality"); } - if (context->IntersectDomainWith( - target, context->DomainOf(ref_x).DivisionBy(divisor))) { - context->UpdateRuleStats( - "int_div: updated domain of target in target = X / cte"); + bool domain_modified = false; + if (context->IntersectDomainWith(target, + context->DomainOf(ref_x).DivisionBy(divisor), + &domain_modified)) { + if (domain_modified) { + context->UpdateRuleStats( + "int_div: updated domain of target in target = X / cte"); + } + } else { + // Model is unsat. + return false; } // TODO(user): reduce the domain of X by introducing an @@ -924,6 +978,8 @@ bool ExploitEquivalenceRelations(ConstraintProto* ct, } void DivideLinearByGcd(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return; + // Compute the GCD of all coefficients. int64 gcd = 0; const int num_vars = ct->linear().vars().size(); @@ -943,6 +999,8 @@ void DivideLinearByGcd(ConstraintProto* ct, PresolveContext* context) { } bool CanonicalizeLinear(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + // First regroup the terms on the same variables and sum the fixed ones. // // TODO(user): move terms in context to reuse its memory? Add a quick pass @@ -1021,6 +1079,8 @@ bool CanonicalizeLinear(ConstraintProto* ct, PresolveContext* context) { } bool RemoveSingletonInLinear(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + const bool was_affine = gtl::ContainsKey(context->affine_constraints, ct); if (was_affine) return false; @@ -1074,6 +1134,8 @@ bool RemoveSingletonInLinear(ConstraintProto* ct, PresolveContext* context) { } bool PresolveLinear(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + Domain rhs = ReadDomainFromProto(ct->linear()); // Empty constraint? @@ -1094,10 +1156,14 @@ bool PresolveLinear(ConstraintProto* ct, PresolveContext* context) { context->UpdateRuleStats("linear: size one"); const int var = PositiveRef(ct->linear().vars(0)); if (coeff == 1) { - context->IntersectDomainWith(var, rhs); + if (!context->IntersectDomainWith(var, rhs)) { + return true; + } } else { DCHECK_EQ(coeff, -1); // Because of the GCD above. - context->IntersectDomainWith(var, rhs.Negation()); + if (!context->IntersectDomainWith(var, rhs.Negation())) { + return true; + } } return RemoveConstraint(ct, context); } @@ -1173,7 +1239,12 @@ bool PresolveLinear(ConstraintProto* ct, PresolveContext* context) { new_domain = left_domains[i] .AdditionWith(right_domain) .InverseMultiplicationBy(-ct->linear().coeffs(i)); - if (context->IntersectDomainWith(ct->linear().vars(i), new_domain)) { + bool domain_modified = false; + if (!context->IntersectDomainWith(ct->linear().vars(i), new_domain, + &domain_modified)) { + return true; + } + if (domain_modified) { new_bounds = true; } } @@ -1218,6 +1289,8 @@ bool PresolveLinear(ConstraintProto* ct, PresolveContext* context) { // This operation is similar to coefficient strengthening in the MIP world. void ExtractEnforcementLiteralFromLinearConstraint(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return; + const LinearConstraintProto& arg = ct->linear(); const int num_vars = arg.vars_size(); int64 min_sum = 0; @@ -1309,6 +1382,7 @@ void ExtractEnforcementLiteralFromLinearConstraint(ConstraintProto* ct, } void ExtractAtMostOneFromLinear(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return; if (HasEnforcementLiteral(*ct)) return; const Domain domain = ReadDomainFromProto(ct->linear()); @@ -1362,6 +1436,8 @@ void ExtractAtMostOneFromLinear(ConstraintProto* ct, PresolveContext* context) { // Convert some linear constraint involving only Booleans to their Boolean // form. bool PresolveLinearOnBooleans(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + // TODO(user): the alternative to mark any newly created constraints might // be better. if (gtl::ContainsKey(context->affine_constraints, ct)) return false; @@ -1553,6 +1629,8 @@ bool PresolveLinearOnBooleans(ConstraintProto* ct, PresolveContext* context) { } bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + const int start = ct->interval().start(); const int end = ct->interval().end(); const int size = ct->interval().size(); @@ -1575,14 +1653,23 @@ bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) { if (!ct->enforcement_literal().empty()) return false; bool changed = false; - changed |= context->IntersectDomainWith( - end, context->DomainOf(start).AdditionWith(context->DomainOf(size))); - changed |= context->IntersectDomainWith( - start, - context->DomainOf(end).AdditionWith(context->DomainOf(size).Negation())); - changed |= context->IntersectDomainWith( - size, - context->DomainOf(end).AdditionWith(context->DomainOf(start).Negation())); + if (!context->IntersectDomainWith( + end, context->DomainOf(start).AdditionWith(context->DomainOf(size)), + &changed)) { + return false; + } + if (!context->IntersectDomainWith(start, + context->DomainOf(end).AdditionWith( + context->DomainOf(size).Negation()), + &changed)) { + return false; + } + if (!context->IntersectDomainWith(size, + context->DomainOf(end).AdditionWith( + context->DomainOf(start).Negation()), + &changed)) { + return false; + } if (changed) { context->UpdateRuleStats("interval: reduced domains"); } @@ -1603,6 +1690,8 @@ bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) { } bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + const int index_ref = ct->element().index(); const int target_ref = ct->element().target(); @@ -1616,9 +1705,10 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { { bool reduced_index_domain = false; - if (context->IntersectDomainWith( - index_ref, Domain(0, ct->element().vars_size() - 1))) { - reduced_index_domain = true; + if (!context->IntersectDomainWith(index_ref, + Domain(0, ct->element().vars_size() - 1), + &reduced_index_domain)) { + return false; } // Filter possible index values. Accumulate variable domains to build @@ -1633,7 +1723,11 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { const int ref = ct->element().vars(value); const Domain domain = context->DomainOf(ref); if (domain.IntersectionWith(target_domain).IsEmpty()) { - context->IntersectDomainWith(index_ref, Domain(value).Complement()); + bool domain_modified = false; + if (!context->IntersectDomainWith( + index_ref, Domain(value).Complement(), &domain_modified)) { + return false; + } reduced_index_domain = true; } else { ++num_vars; @@ -1652,8 +1746,12 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { if (reduced_index_domain) { context->UpdateRuleStats("element: reduced index domain"); } - if (context->IntersectDomainWith(target_ref, infered_domain)) { - if (context->DomainIsEmpty(target_ref)) return true; + bool domain_modified = false; + if (!context->IntersectDomainWith(target_ref, infered_domain, + &domain_modified)) { + return true; + } + if (domain_modified) { context->UpdateRuleStats("element: reduced target domain"); } } @@ -1793,7 +1891,9 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } if (index_domain.Size() > valid_index_values.size()) { const Domain new_domain = Domain::FromValues(valid_index_values); - context->IntersectDomainWith(index_ref, new_domain); + if (!context->IntersectDomainWith(index_ref, new_domain)) { + return true; + } context->UpdateRuleStats( "CHECK element: reduce index domain from affine target"); } @@ -1835,8 +1935,10 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } } if (possible_indices.size() < index_domain.Size()) { - context->IntersectDomainWith(index_ref, - Domain::FromValues(possible_indices)); + if (!context->IntersectDomainWith(index_ref, + Domain::FromValues(possible_indices))) { + return true; + } } context->UpdateRuleStats( "element: reduce index domain when target equals index"); @@ -1906,6 +2008,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (HasEnforcementLiteral(*ct)) return false; if (ct->table().negated()) return false; if (ct->table().vars().empty()) { @@ -1993,10 +2096,13 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { bool changed = false; for (int j = 0; j < num_vars; ++j) { const int ref = ct->table().vars(j); - changed |= context->IntersectDomainWith( - PositiveRef(ref), Domain::FromValues(std::vector( - new_domains[j].begin(), new_domains[j].end()))); - if (context->is_unsat) return true; + if (!context->IntersectDomainWith( + PositiveRef(ref), + Domain::FromValues(std::vector(new_domains[j].begin(), + new_domains[j].end())), + &changed)) { + return true; + } } if (changed) { context->UpdateRuleStats("table: reduced variable domains"); @@ -2052,7 +2158,9 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { } bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (HasEnforcementLiteral(*ct)) return false; + AllDifferentConstraintProto& all_diff = *ct->mutable_all_diff(); const int size = all_diff.vars_size(); @@ -2077,9 +2185,10 @@ bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) { for (int j = 0; j < size; ++j) { if (i == j) continue; if (context->DomainContains(all_diff.vars(j), value)) { - context->IntersectDomainWith(all_diff.vars(j), - Domain(value).Complement()); - if (context->is_unsat) return true; + if (!context->IntersectDomainWith(all_diff.vars(j), + Domain(value).Complement())) { + return true; + } propagated = true; } } @@ -2101,6 +2210,8 @@ bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) { } bool PresolveNoOverlap(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + const NoOverlapConstraintProto& proto = ct->no_overlap(); // Filter absent intervals. @@ -2127,6 +2238,8 @@ bool PresolveNoOverlap(ConstraintProto* ct, PresolveContext* context) { } bool PresolveCumulative(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; + const CumulativeConstraintProto& proto = ct->cumulative(); // Filter absent intervals. @@ -2184,17 +2297,19 @@ bool PresolveCumulative(ConstraintProto* ct, PresolveContext* context) { if (demand_min > capacity) { context->UpdateRuleStats("cumulative: demand_min exceeds capacity"); if (ct.enforcement_literal().empty()) { - context->is_unsat = true; - return changed; + return context->NotifyThatModelIsUnsat(); } else { CHECK_EQ(ct.enforcement_literal().size(), 1); - context->SetLiteralToFalse(ct.enforcement_literal(0)); + if (!context->SetLiteralToFalse(ct.enforcement_literal(0))) return true; } return changed; } else if (demand_max > capacity) { if (ct.enforcement_literal().empty()) { context->UpdateRuleStats("cumulative: demand_max exceeds capacity."); - context->IntersectDomainWith(demand_ref, Domain(kint64min, capacity)); + if (!context->IntersectDomainWith(demand_ref, + Domain(kint64min, capacity))) { + return true; + } } else { // TODO(user): we abort because we cannot convert this to a no_overlap // for instance. @@ -2229,6 +2344,7 @@ bool PresolveCumulative(ConstraintProto* ct, PresolveContext* context) { } bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (HasEnforcementLiteral(*ct)) return false; CircuitConstraintProto& proto = *ct->mutable_circuit(); @@ -2257,8 +2373,7 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { if (refs.size() == 1) { if (!context->LiteralIsTrue(refs.front())) { ++num_fixed_at_true; - context->SetLiteralToTrue(refs.front()); - if (context->is_unsat) return false; + if (!context->SetLiteralToTrue(refs.front())) return true; } continue; } @@ -2276,7 +2391,7 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { if (num_true > 0) { for (const int ref : refs) { if (ref != true_ref) { - context->SetLiteralToFalse(ref); + if (!context->SetLiteralToFalse(ref)) return true; } } } @@ -2298,8 +2413,7 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { if (context->LiteralIsFalse(ref)) continue; if (context->LiteralIsTrue(ref)) { if (next[proto.tails(i)] != -1) { - context->is_unsat = true; - return true; + return context->NotifyThatModelIsUnsat(); } next[proto.tails(i)] = proto.heads(i); if (proto.tails(i) != proto.heads(i)) { @@ -2326,8 +2440,7 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { if (incoming_arcs[i].empty() && outgoing_arcs[i].empty()) continue; if (new_in_degree[i] == 0 || new_out_degree[i] == 0) { - context->is_unsat = true; - return true; + return context->NotifyThatModelIsUnsat(); } } @@ -2345,9 +2458,9 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { for (int i = 0; i < num_arcs; ++i) { if (visited[proto.tails(i)]) continue; if (proto.tails(i) == proto.heads(i)) { - context->SetLiteralToTrue(proto.literals(i)); + if (!context->SetLiteralToTrue(proto.literals(i))) return true; } else { - context->SetLiteralToFalse(proto.literals(i)); + if (!context->SetLiteralToFalse(proto.literals(i))) return true; } } context->UpdateRuleStats("circuit: fully specified."); @@ -2395,6 +2508,7 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { } bool PresolveAutomaton(ConstraintProto* ct, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; if (HasEnforcementLiteral(*ct)) return false; AutomatonConstraintProto& proto = *ct->mutable_automaton(); if (proto.vars_size() == 0 || proto.transition_label_size() == 0) { @@ -2516,9 +2630,13 @@ bool PresolveAutomaton(ConstraintProto* ct, PresolveContext* context) { bool removed_values = false; for (int time = 0; time < n; ++time) { - removed_values |= context->IntersectDomainWith( - vars[time], Domain::FromValues({reached_values[time].begin(), - reached_values[time].end()})); + if (!context->IntersectDomainWith( + vars[time], + Domain::FromValues( + {reached_values[time].begin(), reached_values[time].end()}), + &removed_values)) { + return false; + } } if (removed_values) { context->UpdateRuleStats("automaton: reduced variable domains"); @@ -2619,7 +2737,7 @@ void ExtractClauses(const ClauseContainer& container, CpModelProto* proto) { } void Probe(TimeLimit* global_time_limit, PresolveContext* context) { - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; // Update the domain in the current CpModelProto. for (int i = 0; i < context->working_model->variables_size(); ++i) { @@ -2648,14 +2766,12 @@ void Probe(TimeLimit* global_time_limit, PresolveContext* context) { if (mapping->ConstraintIsAlreadyLoaded(&ct)) continue; CHECK(LoadConstraint(ct, &model)); if (sat_solver->IsModelUnsat()) { - context->is_unsat = true; - return; + return (void)context->NotifyThatModelIsUnsat(); } } encoder->AddAllImplicationsBetweenAssociatedLiterals(); if (!sat_solver->Propagate()) { - context->is_unsat = true; - return; + return (void)context->NotifyThatModelIsUnsat(); } // Probe. @@ -2665,8 +2781,7 @@ void Probe(TimeLimit* global_time_limit, PresolveContext* context) { auto* implication_graph = model.GetOrCreate(); ProbeBooleanVariables(/*deterministic_time_limit=*/1.0, &model); if (sat_solver->IsModelUnsat() || !implication_graph->DetectEquivalences()) { - context->is_unsat = true; - return; + return (void)context->NotifyThatModelIsUnsat(); } // Update the presolve context with fixed Boolean variables. @@ -2675,7 +2790,7 @@ void Probe(TimeLimit* global_time_limit, PresolveContext* context) { const int var = mapping->GetProtoVariableFromBooleanVariable(l.Variable()); if (var >= 0) { const int ref = l.IsPositive() ? var : NegatedRef(var); - context->SetLiteralToTrue(ref); + if (!context->SetLiteralToTrue(ref)) return; } } @@ -2687,7 +2802,9 @@ void Probe(TimeLimit* global_time_limit, PresolveContext* context) { if (!mapping->IsBoolean(var)) { const Domain new_domain = integer_trail->InitialVariableDomain(mapping->Integer(var)); - context->IntersectDomainWith(var, new_domain); + if (!context->IntersectDomainWith(var, new_domain)) { + return; + } continue; } @@ -2707,7 +2824,7 @@ void Probe(TimeLimit* global_time_limit, PresolveContext* context) { void PresolvePureSatPart(PresolveContext* context) { // TODO(user,user): Reenable some SAT presolve with // enumerate_all_solutions set to true. - if (context->is_unsat || context->enumerate_all_solutions) return; + if (context->ModelIsUnsat() || context->enumerate_all_solutions) return; const int num_variables = context->working_model->variables_size(); SatPostsolver postsolver(num_variables); @@ -2808,8 +2925,7 @@ void PresolvePureSatPart(PresolveContext* context) { const int old_num_clause = postsolver.NumClauses(); if (!presolver.Presolve(can_be_removed)) { VLOG(1) << "UNSAT during SAT presolve."; - context->is_unsat = true; - return; + return (void)context->NotifyThatModelIsUnsat(); } if (old_num_clause == postsolver.NumClauses()) break; } @@ -2856,7 +2972,7 @@ void MaybeDivideByGcd(std::map* objective_map, int64* divisor) { // effect. Like on a triangular matrix where each expansion reduced the size // of the objective by one. Investigate and fix? void ExpandObjective(PresolveContext* context) { - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; // Convert the objective linear expression to a map for ease of use below. // We also only use affine representative. @@ -3110,7 +3226,7 @@ void ExpandObjective(PresolveContext* context) { } void MergeNoOverlapConstraints(PresolveContext* context) { - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; const int num_constraints = context->working_model->constraints_size(); int old_num_no_overlaps = 0; @@ -3173,7 +3289,7 @@ void MergeNoOverlapConstraints(PresolveContext* context) { // Extracts cliques from bool_and and small at_most_one constraints and // transforms them into maximal cliques. void TransformIntoMaxCliques(PresolveContext* context) { - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; auto convert = [](int ref) { if (RefIsPositive(ref)) return Literal(BooleanVariable(ref), true); @@ -3219,8 +3335,7 @@ void TransformIntoMaxCliques(PresolveContext* context) { if (clique.size() <= 100) graph->AddAtMostOne(clique); } if (!graph->DetectEquivalences()) { - context->is_unsat = true; - return; + return (void)context->NotifyThatModelIsUnsat(); } graph->TransformIntoMaxCliques(&cliques); @@ -3256,6 +3371,7 @@ void TransformIntoMaxCliques(PresolveContext* context) { } bool PresolveOneConstraint(int c, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; ConstraintProto* ct = context->working_model->mutable_constraints(c); // Generic presolve to exploit variable/literal equivalence. @@ -3296,7 +3412,6 @@ bool PresolveOneConstraint(int c, PresolveContext* context) { if (PresolveLinear(ct, context)) { context->UpdateConstraintVariableUsage(c); } - if (context->is_unsat) return false; if (ct->constraint_case() == ConstraintProto::ConstraintCase::kLinear) { const int old_num_enforcement_literals = ct->enforcement_literal_size(); @@ -3305,10 +3420,8 @@ bool PresolveOneConstraint(int c, PresolveContext* context) { if (PresolveLinear(ct, context)) { context->UpdateConstraintVariableUsage(c); } - if (context->is_unsat) return false; } } - if (ct->constraint_case() == ConstraintProto::ConstraintCase::kLinear) { return PresolveLinearOnBooleans(ct, context); } @@ -3360,6 +3473,7 @@ bool ProcessSetPPCSubset(int c1, int c2, const std::vector& c2_minus_c1, const std::vector& original_constraint_index, std::vector* marked_for_removal, PresolveContext* context) { + if (context->ModelIsUnsat()) return false; CHECK(!(*marked_for_removal)[c1]); CHECK(!(*marked_for_removal)[c2]); ConstraintProto* ct1 = context->working_model->mutable_constraints( @@ -3370,7 +3484,7 @@ bool ProcessSetPPCSubset(int c1, int c2, const std::vector& c2_minus_c1, ct2->constraint_case() == ConstraintProto::ConstraintCase::kAtMostOne) { // fix extras in c2 to 0 for (const int literal : c2_minus_c1) { - context->SetLiteralToFalse(literal); + if (!context->SetLiteralToFalse(literal)) return true; context->UpdateRuleStats("setppc: fixed variables"); } return true; @@ -3434,7 +3548,7 @@ bool ProcessSetPPC(PresolveContext* context, TimeLimit* time_limit) { if (PresolveOneConstraint(c, context)) { context->UpdateConstraintVariableUsage(c); } - if (context->is_unsat) return false; + if (context->ModelIsUnsat()) return false; } if (ct->constraint_case() == ConstraintProto::ConstraintCase::kBoolOr || ct->constraint_case() == ConstraintProto::ConstraintCase::kAtMostOne) { @@ -3549,7 +3663,7 @@ bool ProcessSetPPC(PresolveContext* context, TimeLimit* time_limit) { } void TryToSimplifyDomains(PresolveContext* context) { - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; const int num_vars = context->working_model->variables_size(); for (int var = 0; var < num_vars; ++var) { @@ -3609,6 +3723,7 @@ void TryToSimplifyDomains(PresolveContext* context) { FillDomainInProto(scaled_domain, var_proto); } context->InitializeNewDomains(); + if (context->ModelIsUnsat()) return; ConstraintProto* const ct = context->working_model->add_constraints(); LinearConstraintProto* const lin = ct->mutable_linear(); @@ -3627,7 +3742,7 @@ void TryToSimplifyDomains(PresolveContext* context) { } void PresolveToFixPoint(PresolveContext* context, TimeLimit* time_limit) { - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; // This is used for constraint having unique variables in them (i.e. not // appearing anywhere else) to not call the presolve more than once for this @@ -3638,9 +3753,9 @@ void PresolveToFixPoint(PresolveContext* context, TimeLimit* time_limit) { std::vector in_queue(context->working_model->constraints_size(), true); std::deque queue(context->working_model->constraints_size()); std::iota(queue.begin(), queue.end(), 0); - while (!queue.empty() && !context->is_unsat) { + while (!queue.empty() && !context->ModelIsUnsat()) { if (time_limit != nullptr && time_limit->LimitReached()) break; - while (!queue.empty() && !context->is_unsat) { + while (!queue.empty() && !context->ModelIsUnsat()) { if (time_limit != nullptr && time_limit->LimitReached()) break; const int c = queue.front(); in_queue[c] = false; @@ -3697,13 +3812,9 @@ void PresolveToFixPoint(PresolveContext* context, TimeLimit* time_limit) { // TODO(user): Avoid reprocessing the constraints that changed the variables // with the use of timestamp. const int old_queue_size = queue.size(); + if (context->ModelIsUnsat()) return; for (const int v : context->modified_domains.PositionsSetAtLeastOnce()) { - if (context->DomainIsEmpty(v)) { - context->is_unsat = true; - break; - } if (context->IsFixed(v)) context->ExploitFixedDomain(v); - for (const int c : context->var_to_constraints[v]) { if (c >= 0 && !in_queue[c]) { in_queue[c] = true; @@ -3718,7 +3829,7 @@ void PresolveToFixPoint(PresolveContext* context, TimeLimit* time_limit) { context->modified_domains.SparseClearAll(); } - if (context->is_unsat) return; + if (context->ModelIsUnsat()) return; // Make sure we filter out absent intervals. // @@ -3750,7 +3861,7 @@ void PresolveToFixPoint(PresolveContext* context, TimeLimit* time_limit) { } void RemoveUnusedEquivalentVariables(PresolveContext* context) { - if (context->is_unsat || context->enumerate_all_solutions) return; + if (context->ModelIsUnsat() || context->enumerate_all_solutions) return; // Remove all affine constraints (they will be re-added later if // needed) in the presolved model. @@ -3785,7 +3896,12 @@ void RemoveUnusedEquivalentVariables(PresolveContext* context) { const Domain implied = context->DomainOf(var) .AdditionWith({-r.offset, -r.offset}) .InverseMultiplicationBy(r.coeff); - if (context->IntersectDomainWith(r.representative, implied)) { + bool domain_modified = false; + if (!context->IntersectDomainWith(r.representative, implied, + &domain_modified)) { + return; + } + if (domain_modified) { LOG(WARNING) << "Domain of " << r.representative << " was not fully propagated using the affine relation " << "(representative =" << r.representative @@ -3892,7 +4008,7 @@ bool PresolveCpModel(const PresolveOptions& options, // TODO(user): instead of extracting at most one, extra pairwise conflicts // and add them to bool_and clauses? this is some sort of small scale probing, // but good for sat presolve and clique later? - if (!context.is_unsat) { + if (!context.ModelIsUnsat()) { const int old_size = context.working_model->constraints_size(); for (int c = 0; c < old_size; ++c) { ConstraintProto* ct = context.working_model->mutable_constraints(c); @@ -3919,7 +4035,7 @@ bool PresolveCpModel(const PresolveOptions& options, PresolveToFixPoint(&context, options.time_limit); } - if (context.is_unsat) { + if (context.ModelIsUnsat()) { // Set presolved_model to the simplest UNSAT problem (empty clause). presolved_model->Clear(); presolved_model->add_constraints()->mutable_bool_or(); @@ -3985,6 +4101,7 @@ bool PresolveCpModel(const PresolveOptions& options, const int var = PositiveRef(ref); // Remove fixed variables. + if (context.ModelIsUnsat()) return true; if (context.IsFixed(var)) continue; // There is not point having a variable appear twice, so we only keep diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index ac02f47633..50505eb125 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -60,16 +60,23 @@ struct PresolveContext { // Returns true if this ref only appear in one constraint. bool VariableIsUniqueAndRemovable(int ref) const; - // Returns true iff the domain changed. - bool IntersectDomainWith(int ref, const Domain& domain); + // Returns false if the new domain is empty. Sets 'domain_modified' (if + // provided) to true iff the domain is modified otherwise does not change it. + ABSL_MUST_USE_RESULT bool IntersectDomainWith( + int ref, const Domain& domain, bool* domain_modified = nullptr); - // TODO(user): These function and IntersectDomainWith() can make the model - // UNSAT and leave the PresolveContext() in an unusable state. Either make - // sure that is not a problem until the client check for is_unsat, or use - // a construct like MUST_USE_RESULT to ensure that the client do not forget to - // test for empty domains. - void SetLiteralToFalse(int lit); - void SetLiteralToTrue(int lit); + // Returns false if the 'lit' doesn't have the desired value in the domain. + ABSL_MUST_USE_RESULT bool SetLiteralToFalse(int lit); + ABSL_MUST_USE_RESULT bool SetLiteralToTrue(int lit); + + // This function always return false. It is just a way to make a little bit + // more sure that we abort right away when infeasibility is detected. + ABSL_MUST_USE_RESULT bool NotifyThatModelIsUnsat() { + DCHECK(!is_unsat); + is_unsat = true; + return false; + } + bool ModelIsUnsat() const { return is_unsat; } // Stores a description of a rule that was just applied to have a summary of // what the presolve did at the end. @@ -151,9 +158,6 @@ struct PresolveContext { CpModelProto* working_model; CpModelProto* mapping_model; - // Initially false, and set to true on the first inconsistency. - bool is_unsat = false; - // Indicate if we are enumerating all solutions. This disable some presolve // rules. bool enumerate_all_solutions = false; @@ -173,6 +177,9 @@ struct PresolveContext { private: void AddVariableUsage(int c); + // Initially false, and set to true on the first inconsistency. + bool is_unsat = false; + // The current domain of each variables. std::vector domains; }; diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 8f4f89876b..a9758e8605 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1880,10 +1880,20 @@ CpSolverResponse SolvePureSatModel(const CpModelProto& model_proto, return response; } +void UpdateDomain(int64 new_lb, int64 new_ub, + IntegerVariableProto* mutable_var) { + const Domain old_domain = ReadDomainFromProto(*mutable_var); + const Domain new_domain = old_domain.IntersectionWith(Domain(new_lb, new_ub)); + CHECK(!new_domain.IsEmpty()) << "Invalid bounds."; + + FillDomainInProto(new_domain, mutable_var); +} + CpSolverResponse SolveCpModelWithLNS( const CpModelProto& model_proto, const std::function& observer, - int num_workers, int worker_id, WallTimer* wall_timer, Model* model) { + int num_workers, int worker_id, SharedBoundsManager* shared_bounds_manager, + WallTimer* wall_timer, Model* model) { SatParameters* parameters = model->GetOrCreate(); parameters->set_stop_after_first_solution(true); CpSolverResponse response; @@ -1891,19 +1901,18 @@ CpSolverResponse SolveCpModelWithLNS( if (synchro != nullptr && synchro->f != nullptr) { response = synchro->f(); } else { - response = SolveCpModelInternal( - model_proto, /*is_real_solve=*/true, observer, - /*shared_bounds_manager=*/nullptr, wall_timer, model); + response = + SolveCpModelInternal(model_proto, /*is_real_solve=*/true, observer, + shared_bounds_manager, wall_timer, model); } + CpModelProto mutable_model_proto = model_proto; if (response.status() != CpSolverStatus::FEASIBLE) { return response; } const bool focus_on_decision_variables = parameters->lns_focus_on_decision_variables(); - // TODO(user): Find a way to propagate the level zero bounds from the other - // worker inside this base LNS problem. - const NeighborhoodGeneratorHelper helper(&model_proto, + const NeighborhoodGeneratorHelper helper(&mutable_model_proto, focus_on_decision_variables); // For now we will just alternate between our possible neighborhoods. @@ -1973,6 +1982,31 @@ CpSolverResponse SolveCpModelWithLNS( response.objective_value() == response.best_objective_bound(); }, [&](int64 seed) { + // Update the bounds on mutable model proto. + if (shared_bounds_manager != nullptr) { + std::vector model_variables; + std::vector new_lower_bounds; + std::vector new_upper_bounds; + shared_bounds_manager->GetChangedBounds(worker_id, &model_variables, + &new_lower_bounds, + &new_upper_bounds); + + for (int i = 0; i < model_variables.size(); ++i) { + const int var = model_variables[i]; + const int64 new_lb = new_lower_bounds[i]; + const int64 new_ub = new_upper_bounds[i]; + if (VLOG_IS_ON(2)) { + const auto& domain = mutable_model_proto.variables(var).domain(); + const int64 old_lb = domain.Get(0); + const int64 old_ub = domain.Get(domain.size() - 1); + VLOG(2) << "Variable: " << var << " old domain: [" << old_lb + << ", " << old_ub << "] new domain: [" << new_lb << ", " + << new_ub << "]"; + } + UpdateDomain(new_lb, new_ub, + mutable_model_proto.mutable_variables(var)); + } + } AdaptiveParameterValue& difficulty = difficulties[seed % generators.size()]; const double saved_difficulty = difficulty.value(); @@ -2284,8 +2318,8 @@ CpSolverResponse SolveCpModelParallel( // TODO(user,user): Provide a better diversification for different // seeds. thread_response = SolveCpModelWithLNS( - model_proto, solution_observer, num_search_workers, - worker_id + random_seed, wall_timer, &local_model); + model_proto, solution_observer, num_search_workers, worker_id, + shared_bounds_manager.get(), wall_timer, &local_model); } else { thread_response = SolveCpModelInternal( model_proto, true, solution_observer, shared_bounds_manager.get(), @@ -2495,6 +2529,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { // seeds. const int random_seed = model->GetOrCreate()->random_seed(); response = SolveCpModelWithLNS(new_model, observer_function, 1, random_seed, + /*shared_bounds_manager=*/nullptr, &wall_timer, model); } else { // Normal sequential run. response = SolveCpModelInternal( diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index 68dc391eec..abe3c75132 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -70,37 +70,13 @@ void AddCumulativeRelaxation(const std::vector& x, model->Add(Cumulative(x, sizes, capacity)); } -#define RETURN_IF_FALSE(f) \ - if (!(f)) return false; +namespace { -// ------ Base class ----- - -NonOverlappingRectanglesBasePropagator::NonOverlappingRectanglesBasePropagator( - const std::vector& x, - const std::vector& y, bool strict, Model* model, - IntegerTrail* integer_trail) - : num_boxes_(x.size()), - x_(x, model), - y_(y, model), - strict_(strict), - integer_trail_(integer_trail) { - CHECK_GT(num_boxes_, 0); -} - -NonOverlappingRectanglesBasePropagator:: - ~NonOverlappingRectanglesBasePropagator() {} - -void NonOverlappingRectanglesBasePropagator::FillCachedAreas() { - cached_areas_.resize(num_boxes_); - for (int box = 0; box < num_boxes_; ++box) { - // We assume that the min-size of a box never changes. - cached_areas_[box] = x_.DurationMin(box) * y_.DurationMin(box); - } -} - -// We maximize the number of trailing bits set to 0 within a range. -IntegerValue NonOverlappingRectanglesBasePropagator::FindCanonicalValue( - IntegerValue lb, IntegerValue ub) { +// We want for different propagation to reuse as much as possible the same +// line. The idea behind this is to compute the 'canonical' line to use +// when explaining that boxes overlap on the 'y_dim' dimension. We compute +// the multiple of the biggest power of two that is common to all boxes. +IntegerValue FindCanonicalValue(IntegerValue lb, IntegerValue ub) { if (lb == ub) return lb; if (lb <= 0 && ub > 0) return IntegerValue(0); if (lb < 0 && ub <= 0) { @@ -121,173 +97,75 @@ IntegerValue NonOverlappingRectanglesBasePropagator::FindCanonicalValue( return candidate; } -std::vector> -NonOverlappingRectanglesBasePropagator::SplitDisjointBoxes( - std::vector boxes, SchedulingConstraintHelper* x_dim) { - std::vector> result(1); +std::vector> SplitDisjointBoxes( + absl::Span boxes, SchedulingConstraintHelper* x_dim) { + std::vector> result; std::sort(boxes.begin(), boxes.end(), [x_dim](int a, int b) { - return x_dim->StartMin(a) < x_dim->StartMin(b); + return x_dim->StartMin(a) < x_dim->StartMin(b) || + (x_dim->StartMin(a) == x_dim->StartMin(b) && a < b); }); - result.back().push_back(boxes[0]); + int current_start = 0; + std::size_t current_length = 1; IntegerValue current_max_end = x_dim->EndMax(boxes[0]); for (int b = 1; b < boxes.size(); ++b) { const int box = boxes[b]; if (x_dim->StartMin(box) < current_max_end) { // Merge. - result.back().push_back(box); + current_length++; current_max_end = std::max(current_max_end, x_dim->EndMax(box)); } else { - if (result.back().size() == 1) { - // Overwrite - result.back().clear(); - } else { - result.push_back(std::vector()); + if (current_length > 1) { // Ignore lists of size 1. + result.push_back({&boxes[current_start], current_length}); } - result.back().push_back(box); + current_start = b; + current_length = 1; current_max_end = x_dim->EndMax(box); } } + // Push last span. + if (current_length > 1) { + result.push_back({&boxes[current_start], current_length}); + } return result; } -bool NonOverlappingRectanglesBasePropagator:: - FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - SchedulingConstraintHelper* x_dim, SchedulingConstraintHelper* y_dim, - std::function& boxes)> inner_propagate) { - // Restore the two dimensions in a sane state. - x_dim->SetTimeDirection(true); - x_dim->ClearOtherHelper(); - x_dim->SetAllIntervalsVisible(); - y_dim->SetTimeDirection(true); - y_dim->ClearOtherHelper(); - y_dim->SetAllIntervalsVisible(); +} // namespace - std::map> event_to_overlapping_boxes; - std::set events; - - std::vector active_boxes; - - for (int box = 0; box < num_boxes_; ++box) { - if (cached_areas_[box] == 0 && !strict_) continue; - const IntegerValue start_max = y_dim->StartMax(box); - const IntegerValue end_min = y_dim->EndMin(box); - if (start_max < end_min) { - events.insert(start_max); - active_boxes.push_back(box); - } - } - - if (active_boxes.size() < 2) return true; - - for (const int box : active_boxes) { - const IntegerValue start_max = y_dim->StartMax(box); - const IntegerValue end_min = y_dim->EndMin(box); - - for (const IntegerValue t : events) { - if (t < start_max) continue; - if (t >= end_min) break; - event_to_overlapping_boxes[t].push_back(box); - } - } - - std::vector events_to_remove; - std::vector previous_overlapping_boxes; - IntegerValue previous_event(-1); - for (const auto& it : event_to_overlapping_boxes) { - const IntegerValue current_event = it.first; - const std::vector& current_overlapping_boxes = it.second; - if (current_overlapping_boxes.size() < 2) { - events_to_remove.push_back(current_event); - continue; - } - if (!previous_overlapping_boxes.empty()) { - if (std::includes(previous_overlapping_boxes.begin(), - previous_overlapping_boxes.end(), - current_overlapping_boxes.begin(), - current_overlapping_boxes.end())) { - events_to_remove.push_back(current_event); - } - } - previous_event = current_event; - previous_overlapping_boxes = current_overlapping_boxes; - } - - for (const IntegerValue event : events_to_remove) { - event_to_overlapping_boxes.erase(event); - } - - std::set> reduced_overlapping_boxes; - for (const auto& it : event_to_overlapping_boxes) { - std::vector> disjoint_boxes = - SplitDisjointBoxes(it.second, x_dim); - for (std::vector& sub_boxes : disjoint_boxes) { - if (sub_boxes.size() > 1) { - std::sort(sub_boxes.begin(), sub_boxes.end()); - reduced_overlapping_boxes.insert(sub_boxes); - } - } - } - - for (const std::vector& boxes : reduced_overlapping_boxes) { - // Collect the common overlapping coordinates of all boxes. - IntegerValue lb(kint64min); - IntegerValue ub(kint64max); - for (const int box : boxes) { - lb = std::max(lb, y_dim->StartMax(box)); - ub = std::min(ub, y_dim->EndMin(box) - 1); - } - CHECK_LE(lb, ub); - - // TODO(user): We should scan the integer trail to find the oldest - // non-empty common interval. Then we can pick the canonical value within - // it. - - // We want for different propagation to reuse as much as possible the same - // line. The idea behind this is to compute the 'canonical' line to use - // when explaining that boxes overlap on the 'y_dim' dimension. We compute - // the multiple of the biggest power of two that is common to all boxes. - const IntegerValue line_to_use_for_reason = FindCanonicalValue(lb, ub); - - // Setup x_dim for propagation. - x_dim->SetOtherHelper(y_dim, line_to_use_for_reason); - x_dim->SetVisibleIntervals(boxes); - - RETURN_IF_FALSE(inner_propagate(boxes)); - } - return true; -} +#define RETURN_IF_FALSE(f) \ + if (!(f)) return false; NonOverlappingRectanglesEnergyPropagator:: NonOverlappingRectanglesEnergyPropagator( const std::vector& x, - const std::vector& y, bool strict, Model* model, - IntegerTrail* integer_trail) - : NonOverlappingRectanglesBasePropagator(x, y, strict, model, - integer_trail) {} + const std::vector& y, Model* model) + : x_(x, model), y_(y, model) {} NonOverlappingRectanglesEnergyPropagator:: ~NonOverlappingRectanglesEnergyPropagator() {} bool NonOverlappingRectanglesEnergyPropagator::Propagate() { - FillCachedAreas(); + cached_areas_.resize(x_.NumTasks()); - std::vector all_boxes; - for (int box = 0; box < num_boxes_; ++box) { + active_boxes_.clear(); + for (int box = 0; box < x_.NumTasks(); ++box) { + cached_areas_[box] = x_.DurationMin(box) * y_.DurationMin(box); if (cached_areas_[box] == 0) continue; - all_boxes.push_back(box); + active_boxes_.push_back(box); } - if (all_boxes.empty()) return true; + if (active_boxes_.empty()) return true; - const std::vector> x_split = - SplitDisjointBoxes(all_boxes, &x_); - for (const std::vector& x_boxes : x_split) { + // const std::vector> x_split = + // SplitDisjointBoxes({&active_boxes_[0], active_boxes_.size()}, &x_); + const std::vector> x_split = + SplitDisjointBoxes(absl::MakeSpan(active_boxes_), &x_); + for (absl::Span x_boxes : x_split) { if (x_boxes.size() <= 1) continue; - const std::vector> y_split = + const std::vector> y_split = SplitDisjointBoxes(x_boxes, &y_); - for (const std::vector& y_boxes : y_split) { + for (absl::Span y_boxes : y_split) { if (y_boxes.size() <= 1) continue; for (const int box : y_boxes) { RETURN_IF_FALSE(FailWhenEnergyIsTooLarge(box, y_boxes)); @@ -298,22 +176,22 @@ bool NonOverlappingRectanglesEnergyPropagator::Propagate() { return true; } -void NonOverlappingRectanglesEnergyPropagator::RegisterWith( +int NonOverlappingRectanglesEnergyPropagator::RegisterWith( GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); x_.WatchAllTasks(id, watcher); y_.WatchAllTasks(id, watcher); - watcher->SetPropagatorPriority(id, 2); + return id; } -void NonOverlappingRectanglesEnergyPropagator::SortNeighbors( - int box, const std::vector& local_boxes) { +void NonOverlappingRectanglesEnergyPropagator::SortBoxesIntoNeighbors( + int box, absl::Span local_boxes) { auto max_span = [](IntegerValue min_a, IntegerValue max_a, IntegerValue min_b, IntegerValue max_b) { return std::max(max_a, max_b) - std::min(min_a, min_b) + 1; }; - cached_distance_to_bounding_box_.assign(num_boxes_, IntegerValue(0)); + cached_distance_to_bounding_box_.assign(x_.NumTasks(), IntegerValue(0)); neighbors_.clear(); const IntegerValue box_x_min = x_.StartMin(box); const IntegerValue box_x_max = x_.EndMax(box); @@ -341,10 +219,10 @@ void NonOverlappingRectanglesEnergyPropagator::SortNeighbors( } bool NonOverlappingRectanglesEnergyPropagator::FailWhenEnergyIsTooLarge( - int box, const std::vector& local_boxes) { + int box, absl::Span local_boxes) { // Note that we only consider the smallest dimension of each boxes here. + SortBoxesIntoNeighbors(box, local_boxes); - SortNeighbors(box, local_boxes); IntegerValue area_min_x = x_.StartMin(box); IntegerValue area_max_x = x_.EndMax(box); IntegerValue area_min_y = y_.StartMin(box); @@ -399,93 +277,241 @@ bool NonOverlappingRectanglesEnergyPropagator::FailWhenEnergyIsTooLarge( return true; } -NonOverlappingRectanglesFastPropagator::NonOverlappingRectanglesFastPropagator( - const std::vector& x, - const std::vector& y, bool strict, Model* model, - IntegerTrail* integer_trail) - : NonOverlappingRectanglesBasePropagator(x, y, strict, model, - integer_trail), - x_overload_checker_(true, &x_), - y_overload_checker_(true, &y_), - forward_x_detectable_precedences_(true, &x_), - backward_x_detectable_precedences_(false, &x_), - forward_y_detectable_precedences_(true, &y_), - backward_y_detectable_precedences_(false, &y_) {} - -NonOverlappingRectanglesFastPropagator:: - ~NonOverlappingRectanglesFastPropagator() {} - -bool NonOverlappingRectanglesFastPropagator::Propagate() { - FillCachedAreas(); - - // Reach fix-point on fast propagators. - RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - &x_, &y_, [this](const std::vector& boxes) { - if (boxes.size() == 2) { - // In that case, we can use simpler algorithms. - // Note that this case happens frequently (~30% of all calls to this - // method according to our tests). - RETURN_IF_FALSE(PropagateTwoBoxes(boxes[0], boxes[1], &x_)); - } else { - RETURN_IF_FALSE(x_overload_checker_.Propagate()); - RETURN_IF_FALSE(forward_x_detectable_precedences_.Propagate()); - RETURN_IF_FALSE(backward_x_detectable_precedences_.Propagate()); - } - return true; - })); - - // We can actually swap dimensions to propagate vertically. - RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - &y_, &x_, [this](const std::vector& boxes) { - if (boxes.size() == 2) { - // In that case, we can use simpler algorithms. - // Note that this case happens frequently (~30% of all calls to this - // method according to our tests). - RETURN_IF_FALSE(PropagateTwoBoxes(boxes[0], boxes[1], &y_)); - } else { - RETURN_IF_FALSE(y_overload_checker_.Propagate()); - RETURN_IF_FALSE(forward_y_detectable_precedences_.Propagate()); - RETURN_IF_FALSE(backward_y_detectable_precedences_.Propagate()); - } - return true; - })); - return true; +NonOverlappingRectanglesDisjunctivePropagator:: + NonOverlappingRectanglesDisjunctivePropagator( + const std::vector& x, + const std::vector& y, bool strict, + bool slow_propagators, Model* model) + : x_intervals_(x), + y_intervals_(y), + x_(x, model), + y_(y, model), + strict_(strict), + slow_propagators_(slow_propagators), + overload_checker_(true, &x_), + forward_detectable_precedences_(true, &x_), + backward_detectable_precedences_(false, &x_), + forward_not_last_(true, &x_), + backward_not_last_(false, &x_), + forward_edge_finding_(true, &x_), + backward_edge_finding_(false, &x_) { + CHECK_GT(x_.NumTasks(), 0); } -void NonOverlappingRectanglesFastPropagator::RegisterWith( +NonOverlappingRectanglesDisjunctivePropagator:: + ~NonOverlappingRectanglesDisjunctivePropagator() {} + +int NonOverlappingRectanglesDisjunctivePropagator::RegisterWith( GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); x_.WatchAllTasks(id, watcher); y_.WatchAllTasks(id, watcher); - watcher->SetPropagatorPriority(id, 3); + return id; +} + +bool NonOverlappingRectanglesDisjunctivePropagator:: + FindBoxesThatMustOverlapAHorizontalLineAndPropagate( + const std::vector& x_intervals, + const std::vector& y_intervals, + std::function inner_propagate) { + // Restore the two dimensions in a sane state. + x_.SetTimeDirection(true); + x_.ClearOtherHelper(); + x_.Init(x_intervals); + + y_.SetTimeDirection(true); + y_.ClearOtherHelper(); + y_.Init(y_intervals); + + // Compute relevant events (line in the y dimension). + absl::flat_hash_map> + event_to_overlapping_boxes; + std::set events; + + std::vector active_boxes; + + for (int box = 0; box < x_intervals.size(); ++box) { + if ((x_.DurationMin(box) == 0 || y_.DurationMin(box) == 0) && !strict_) { + continue; + } + + const IntegerValue start_max = y_.StartMax(box); + const IntegerValue end_min = y_.EndMin(box); + if (start_max < end_min) { + events.insert(start_max); + active_boxes.push_back(box); + } + } + + // Less than 2 boxes, no propagation. + if (active_boxes.size() < 2) return true; + + // Add boxes to the event lists they always overlap with. + for (const int box : active_boxes) { + const IntegerValue start_max = y_.StartMax(box); + const IntegerValue end_min = y_.EndMin(box); + + for (const IntegerValue t : events) { + if (t < start_max) continue; + if (t >= end_min) break; + event_to_overlapping_boxes[t].push_back(box); + } + } + + // Scan events chronologically to remove events where there is only one + // mandatory box, or dominated events lists. + std::vector events_to_remove; + std::vector previous_overlapping_boxes; + IntegerValue previous_event(-1); + for (const IntegerValue current_event : events) { + const std::vector& current_overlapping_boxes = + event_to_overlapping_boxes[current_event]; + if (current_overlapping_boxes.size() < 2) { + events_to_remove.push_back(current_event); + continue; + } + if (!previous_overlapping_boxes.empty()) { + // In case we just add one box to the previous event. + if (std::includes(current_overlapping_boxes.begin(), + current_overlapping_boxes.end(), + previous_overlapping_boxes.begin(), + previous_overlapping_boxes.end())) { + events_to_remove.push_back(previous_event); + continue; + } + } + + previous_event = current_event; + previous_overlapping_boxes = current_overlapping_boxes; + } + + for (const IntegerValue event : events_to_remove) { + events.erase(event); + } + + // Split lists of boxes into disjoint set of boxes (w.r.t. overlap). + absl::flat_hash_set> reduced_overlapping_boxes; + std::vector> boxes_to_propagate; + std::vector> disjoint_boxes; + for (const IntegerValue event : events) { + disjoint_boxes = SplitDisjointBoxes( + absl::MakeSpan(event_to_overlapping_boxes[event]), &x_); + for (absl::Span sub_boxes : disjoint_boxes) { + if (sub_boxes.size() > 1) { + // Boxes are sorted in a stable manner in the Split method. + const auto& insertion = reduced_overlapping_boxes.insert(sub_boxes); + if (insertion.second) boxes_to_propagate.push_back(sub_boxes); + } + } + } + + // And finally propagate. + // TODO(user): Sorting of boxes seems influential on the performance. Test. + for (const absl::Span boxes : boxes_to_propagate) { + std::vector reduced_x; + std::vector reduced_y; + for (const int box : boxes) { + reduced_x.push_back(x_intervals[box]); + reduced_y.push_back(y_intervals[box]); + } + + x_.Init(reduced_x); + y_.Init(reduced_y); + + // Collect the common overlapping coordinates of all boxes. + IntegerValue lb(kint64min); + IntegerValue ub(kint64max); + for (int i = 0; i < reduced_x.size(); ++i) { + lb = std::max(lb, y_.StartMax(i)); + ub = std::min(ub, y_.EndMin(i) - 1); + } + CHECK_LE(lb, ub); + + // TODO(user): We should scan the integer trail to find the oldest + // non-empty common interval. Then we can pick the canonical value within + // it. + + // We want for different propagation to reuse as much as possible the same + // line. The idea behind this is to compute the 'canonical' line to use + // when explaining that boxes overlap on the 'y_dim' dimension. We compute + // the multiple of the biggest power of two that is common to all boxes. + const IntegerValue line_to_use_for_reason = FindCanonicalValue(lb, ub); + + // Setup x_dim for propagation. + x_.SetOtherHelper(&y_, line_to_use_for_reason); + + RETURN_IF_FALSE(inner_propagate()); + } + + return true; +} + +bool NonOverlappingRectanglesDisjunctivePropagator::Propagate() { + const auto slow_propagate = [this]() { + if (x_.NumTasks() <= 2) return true; + RETURN_IF_FALSE(forward_not_last_.Propagate()); + RETURN_IF_FALSE(backward_not_last_.Propagate()); + RETURN_IF_FALSE(backward_edge_finding_.Propagate()); + RETURN_IF_FALSE(forward_edge_finding_.Propagate()); + return true; + }; + + const auto fast_propagate = [this]() { + if (x_.NumTasks() == 2) { + // In that case, we can use simpler algorithms. + // Note that this case happens frequently (~30% of all calls to this + // method according to our tests). + RETURN_IF_FALSE(PropagateTwoBoxes()); + } else { + RETURN_IF_FALSE(overload_checker_.Propagate()); + RETURN_IF_FALSE(forward_detectable_precedences_.Propagate()); + RETURN_IF_FALSE(backward_detectable_precedences_.Propagate()); + } + return true; + }; + + if (slow_propagators_) { + RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( + x_intervals_, y_intervals_, slow_propagate)); + + // We can actually swap dimensions to propagate vertically. + RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( + y_intervals_, x_intervals_, slow_propagate)); + } else { + RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( + x_intervals_, y_intervals_, fast_propagate)); + + // We can actually swap dimensions to propagate vertically. + RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( + y_intervals_, x_intervals_, fast_propagate)); + } + return true; } // Specialized propagation on only two boxes that must intersect with the // given y_line_for_reason. -bool NonOverlappingRectanglesFastPropagator::PropagateTwoBoxes( - int b1, int b2, SchedulingConstraintHelper* x_dim) { +bool NonOverlappingRectanglesDisjunctivePropagator::PropagateTwoBoxes() { // For each direction and each order, we test if the boxes can be disjoint. - const int state = (x_dim->EndMin(b1) <= x_dim->StartMax(b2)) + - 2 * (x_dim->EndMin(b2) <= x_dim->StartMax(b1)); + const int state = + (x_.EndMin(0) <= x_.StartMax(1)) + 2 * (x_.EndMin(1) <= x_.StartMax(0)); - const auto left_box_before_right_box = [](int left, int right, - SchedulingConstraintHelper* x_dim) { + const auto left_box_before_right_box = [this](int left, int right) { // left box pushes right box. - const IntegerValue left_end_min = x_dim->EndMin(left); - if (left_end_min > x_dim->StartMin(right)) { - x_dim->ClearReason(); - x_dim->AddReasonForBeingBefore(left, right); - x_dim->AddEndMinReason(left, left_end_min); - RETURN_IF_FALSE(x_dim->IncreaseStartMin(right, left_end_min)); + const IntegerValue left_end_min = x_.EndMin(left); + if (left_end_min > x_.StartMin(right)) { + x_.ClearReason(); + x_.AddReasonForBeingBefore(left, right); + x_.AddEndMinReason(left, left_end_min); + RETURN_IF_FALSE(x_.IncreaseStartMin(right, left_end_min)); } // right box pushes left box. - const IntegerValue right_start_max = x_dim->StartMax(right); - if (right_start_max < x_dim->EndMax(left)) { - x_dim->ClearReason(); - x_dim->AddReasonForBeingBefore(left, right); - x_dim->AddStartMaxReason(right, right_start_max); - RETURN_IF_FALSE(x_dim->DecreaseEndMax(left, right_start_max)); + const IntegerValue right_start_max = x_.StartMax(right); + if (right_start_max < x_.EndMax(left)) { + x_.ClearReason(); + x_.AddReasonForBeingBefore(left, right); + x_.AddStartMaxReason(right, right_start_max); + RETURN_IF_FALSE(x_.DecreaseEndMax(left, right_start_max)); } return true; @@ -493,16 +519,16 @@ bool NonOverlappingRectanglesFastPropagator::PropagateTwoBoxes( switch (state) { case 0: { // Conflict. - x_dim->ClearReason(); - x_dim->AddReasonForBeingBefore(b1, b2); - x_dim->AddReasonForBeingBefore(b2, b1); - return x_dim->ReportConflict(); + x_.ClearReason(); + x_.AddReasonForBeingBefore(0, 1); + x_.AddReasonForBeingBefore(1, 0); + return x_.ReportConflict(); } case 1: { // b1 is left of b2. - return left_box_before_right_box(b1, b2, x_dim); + return left_box_before_right_box(0, 1); } case 2: { // b2 is left of b1. - return left_box_before_right_box(b2, b1, x_dim); + return left_box_before_right_box(1, 0); } default: { // Nothing to deduce. return true; @@ -510,59 +536,6 @@ bool NonOverlappingRectanglesFastPropagator::PropagateTwoBoxes( } } -NonOverlappingRectanglesSlowPropagator::NonOverlappingRectanglesSlowPropagator( - const std::vector& x, - const std::vector& y, bool strict, Model* model, - IntegerTrail* integer_trail) - : NonOverlappingRectanglesBasePropagator(x, y, strict, model, - integer_trail), - forward_x_not_last_(true, &x_), - backward_x_not_last_(false, &x_), - forward_y_not_last_(true, &y_), - backward_y_not_last_(false, &y_), - forward_x_edge_finding_(true, &x_), - backward_x_edge_finding_(false, &x_), - forward_y_edge_finding_(true, &y_), - backward_y_edge_finding_(false, &y_) {} - -NonOverlappingRectanglesSlowPropagator:: - ~NonOverlappingRectanglesSlowPropagator() {} - -bool NonOverlappingRectanglesSlowPropagator::Propagate() { - FillCachedAreas(); - - RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - &x_, &y_, [this](const std::vector& boxes) { - if (boxes.size() <= 2) return true; - RETURN_IF_FALSE(forward_x_not_last_.Propagate()); - RETURN_IF_FALSE(backward_x_not_last_.Propagate()); - RETURN_IF_FALSE(backward_x_edge_finding_.Propagate()); - RETURN_IF_FALSE(forward_x_edge_finding_.Propagate()); - return true; - })); - - // We can actually swap dimensions to propagate vertically. - RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - &y_, &x_, [this](const std::vector& boxes) { - if (boxes.size() <= 2) return true; - RETURN_IF_FALSE(forward_y_not_last_.Propagate()); - RETURN_IF_FALSE(backward_y_not_last_.Propagate()); - RETURN_IF_FALSE(backward_y_edge_finding_.Propagate()); - RETURN_IF_FALSE(forward_y_edge_finding_.Propagate()); - return true; - })); - - return true; -} - -void NonOverlappingRectanglesSlowPropagator::RegisterWith( - GenericLiteralWatcher* watcher) { - const int id = watcher->Register(this); - x_.WatchAllTasks(id, watcher); - y_.WatchAllTasks(id, watcher); - watcher->SetPropagatorPriority(id, 4); -} - #undef RETURN_IF_FALSE } // namespace sat } // namespace operations_research diff --git a/ortools/sat/diffn.h b/ortools/sat/diffn.h index 4bf7eac380..d788076e2f 100644 --- a/ortools/sat/diffn.h +++ b/ortools/sat/diffn.h @@ -29,122 +29,79 @@ namespace operations_research { namespace sat { -// Non overlapping rectangles. -class NonOverlappingRectanglesBasePropagator : public PropagatorInterface { - public: - // The strict parameters indicates how to place zero width or zero height - // boxes. If strict is true, these boxes must not 'cross' another box, and are - // pushed by the other boxes. - NonOverlappingRectanglesBasePropagator(const std::vector& x, - const std::vector& y, - bool strict, Model* model, - IntegerTrail* integer_trail); - ~NonOverlappingRectanglesBasePropagator() override; - - protected: - void FillCachedAreas(); - IntegerValue FindCanonicalValue(IntegerValue lb, IntegerValue ub); - bool FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - SchedulingConstraintHelper* x_dim, SchedulingConstraintHelper* y_dim, - std::function& boxes)> inner_propagate); - std::vector> SplitDisjointBoxes( - std::vector boxes, SchedulingConstraintHelper* x_dim); - - const int num_boxes_; - SchedulingConstraintHelper x_; - SchedulingConstraintHelper y_; - const bool strict_; - IntegerTrail* integer_trail_; - std::vector cached_areas_; - - private: - DISALLOW_COPY_AND_ASSIGN(NonOverlappingRectanglesBasePropagator); -}; - // Propagates using a box energy reasoning. -class NonOverlappingRectanglesEnergyPropagator - : public NonOverlappingRectanglesBasePropagator { +class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { public: // The strict parameters indicates how to place zero width or zero height // boxes. If strict is true, these boxes must not 'cross' another box, and are // pushed by the other boxes. NonOverlappingRectanglesEnergyPropagator( const std::vector& x, - const std::vector& y, bool strict, Model* model, - IntegerTrail* integer_trail); + const std::vector& y, Model* model); ~NonOverlappingRectanglesEnergyPropagator() override; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); + int RegisterWith(GenericLiteralWatcher* watcher); private: - void SortNeighbors(int box, const std::vector& local_boxes); - bool FailWhenEnergyIsTooLarge(int box, const std::vector& local_boxes); + void SortBoxesIntoNeighbors(int box, absl::Span local_boxes); + bool FailWhenEnergyIsTooLarge(int box, absl::Span local_boxes); + SchedulingConstraintHelper x_; + SchedulingConstraintHelper y_; + std::vector cached_areas_; std::vector neighbors_; std::vector cached_distance_to_bounding_box_; + std::vector active_boxes_; - DISALLOW_COPY_AND_ASSIGN(NonOverlappingRectanglesEnergyPropagator); + NonOverlappingRectanglesEnergyPropagator( + const NonOverlappingRectanglesEnergyPropagator&) = delete; + NonOverlappingRectanglesEnergyPropagator& operator=( + const NonOverlappingRectanglesEnergyPropagator&) = delete; }; -// Embeds the overload checker and the detectable precedences propagators from -// the disjunctive constraint. -class NonOverlappingRectanglesFastPropagator - : public NonOverlappingRectanglesBasePropagator { +// Non overlapping rectangles. +class NonOverlappingRectanglesDisjunctivePropagator + : public PropagatorInterface { public: // The strict parameters indicates how to place zero width or zero height // boxes. If strict is true, these boxes must not 'cross' another box, and are // pushed by the other boxes. - NonOverlappingRectanglesFastPropagator(const std::vector& x, - const std::vector& y, - bool strict, Model* model, - IntegerTrail* integer_trail); - ~NonOverlappingRectanglesFastPropagator() override; + // The slow_propagators select which disjunctive algorithms to propagate. + NonOverlappingRectanglesDisjunctivePropagator( + const std::vector& x, + const std::vector& y, bool strict, + bool slow_propagators, Model* model); + ~NonOverlappingRectanglesDisjunctivePropagator() override; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); + int RegisterWith(GenericLiteralWatcher* watcher); private: - bool PropagateTwoBoxes(int b1, int b2, SchedulingConstraintHelper* x_dim); + bool FindBoxesThatMustOverlapAHorizontalLineAndPropagate( + const std::vector& x_intervals, + const std::vector& y_intervals, + std::function inner_propagate); + bool PropagateTwoBoxes(); - DisjunctiveOverloadChecker x_overload_checker_; - DisjunctiveOverloadChecker y_overload_checker_; - DisjunctiveDetectablePrecedences forward_x_detectable_precedences_; - DisjunctiveDetectablePrecedences backward_x_detectable_precedences_; - DisjunctiveDetectablePrecedences forward_y_detectable_precedences_; - DisjunctiveDetectablePrecedences backward_y_detectable_precedences_; + const std::vector x_intervals_; + const std::vector y_intervals_; + SchedulingConstraintHelper x_; + SchedulingConstraintHelper y_; + const bool strict_; + const bool slow_propagators_; + DisjunctiveOverloadChecker overload_checker_; + DisjunctiveDetectablePrecedences forward_detectable_precedences_; + DisjunctiveDetectablePrecedences backward_detectable_precedences_; + DisjunctiveNotLast forward_not_last_; + DisjunctiveNotLast backward_not_last_; + DisjunctiveEdgeFinding forward_edge_finding_; + DisjunctiveEdgeFinding backward_edge_finding_; - DISALLOW_COPY_AND_ASSIGN(NonOverlappingRectanglesFastPropagator); -}; - -// Embeds the not last and edge finder propagators from the disjunctive -// constraint. -class NonOverlappingRectanglesSlowPropagator - : public NonOverlappingRectanglesBasePropagator { - public: - // The strict parameters indicates how to place zero width or zero height - // boxes. If strict is true, these boxes must not 'cross' another box, and are - // pushed by the other boxes. - NonOverlappingRectanglesSlowPropagator(const std::vector& x, - const std::vector& y, - bool strict, Model* model, - IntegerTrail* integer_trail); - ~NonOverlappingRectanglesSlowPropagator() override; - - bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); - - private: - DisjunctiveNotLast forward_x_not_last_; - DisjunctiveNotLast backward_x_not_last_; - DisjunctiveNotLast forward_y_not_last_; - DisjunctiveNotLast backward_y_not_last_; - DisjunctiveEdgeFinding forward_x_edge_finding_; - DisjunctiveEdgeFinding backward_x_edge_finding_; - DisjunctiveEdgeFinding forward_y_edge_finding_; - DisjunctiveEdgeFinding backward_y_edge_finding_; - - DISALLOW_COPY_AND_ASSIGN(NonOverlappingRectanglesSlowPropagator); + NonOverlappingRectanglesDisjunctivePropagator( + const NonOverlappingRectanglesDisjunctivePropagator&) = delete; + NonOverlappingRectanglesDisjunctivePropagator& operator=( + const NonOverlappingRectanglesDisjunctivePropagator&) = delete; }; // Add a cumulative relaxation. That is, on one direction, it does not enforce @@ -155,60 +112,31 @@ void AddCumulativeRelaxation(const std::vector& x, // Enforces that the boxes with corners in (x, y), (x + dx, y), (x, y + dy) // and (x + dx, y + dy) do not overlap. -// If one box has a zero dimension, then it can be placed anywhere. +// If strict is true, and if one box has a zero dimension, it still cannot +// intersect another box. inline std::function NonOverlappingRectangles( const std::vector& x, - const std::vector& y) { + const std::vector& y, bool is_strict) { return [=](Model* model) { + GenericLiteralWatcher* const watcher = + model->GetOrCreate(); + NonOverlappingRectanglesEnergyPropagator* energy_constraint = - new NonOverlappingRectanglesEnergyPropagator( - x, y, false, model, model->GetOrCreate()); - energy_constraint->RegisterWith( - model->GetOrCreate()); + new NonOverlappingRectanglesEnergyPropagator(x, y, model); + + watcher->SetPropagatorPriority(energy_constraint->RegisterWith(watcher), 3); model->TakeOwnership(energy_constraint); - NonOverlappingRectanglesFastPropagator* fast_constraint = - new NonOverlappingRectanglesFastPropagator( - x, y, false, model, model->GetOrCreate()); - fast_constraint->RegisterWith(model->GetOrCreate()); + NonOverlappingRectanglesDisjunctivePropagator* fast_constraint = + new NonOverlappingRectanglesDisjunctivePropagator( + x, y, is_strict, /*slow_propagators=*/false, model); + watcher->SetPropagatorPriority(fast_constraint->RegisterWith(watcher), 3); model->TakeOwnership(fast_constraint); - NonOverlappingRectanglesSlowPropagator* slow_constraint = - new NonOverlappingRectanglesSlowPropagator( - x, y, false, model, model->GetOrCreate()); - slow_constraint->RegisterWith(model->GetOrCreate()); - - model->TakeOwnership(slow_constraint); - - AddCumulativeRelaxation(x, y, model); - AddCumulativeRelaxation(y, x, model); - }; -} - -// Enforces that the boxes with corners in (x, y), (x + dx, y), (x, y + dy) -// and (x + dx, y + dy) do not overlap. -// If one box has a zero dimension, it still cannot intersect another box. -inline std::function StrictNonOverlappingRectangles( - const std::vector& x, - const std::vector& y) { - return [=](Model* model) { - NonOverlappingRectanglesEnergyPropagator* energy_constraint = - new NonOverlappingRectanglesEnergyPropagator( - x, y, true, model, model->GetOrCreate()); - energy_constraint->RegisterWith( - model->GetOrCreate()); - model->TakeOwnership(energy_constraint); - - NonOverlappingRectanglesFastPropagator* fast_constraint = - new NonOverlappingRectanglesFastPropagator( - x, y, true, model, model->GetOrCreate()); - fast_constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(fast_constraint); - - NonOverlappingRectanglesSlowPropagator* slow_constraint = - new NonOverlappingRectanglesSlowPropagator( - x, y, true, model, model->GetOrCreate()); - slow_constraint->RegisterWith(model->GetOrCreate()); + NonOverlappingRectanglesDisjunctivePropagator* slow_constraint = + new NonOverlappingRectanglesDisjunctivePropagator( + x, y, is_strict, /*slow_propagators=*/true, model); + watcher->SetPropagatorPriority(slow_constraint->RegisterWith(watcher), 4); model->TakeOwnership(slow_constraint); AddCumulativeRelaxation(x, y, model); diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index a8fa7fb189..367e5b1580 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -20,6 +20,7 @@ #include "ortools/sat/all_different.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" +#include "ortools/util/sort.h" namespace operations_research { namespace sat { @@ -153,7 +154,6 @@ void TaskSet::AddEntry(const Entry& e) { if (j <= optimized_restart_) optimized_restart_ = 0; } -// Note that we can keep the optimized_restart_ at its current value here. void TaskSet::NotifyEntryIsNowLastIfPresent(const Entry& e) { const int size = sorted_tasks_.size(); for (int i = 0;; ++i) { @@ -163,6 +163,8 @@ void TaskSet::NotifyEntryIsNowLastIfPresent(const Entry& e) { break; } } + + optimized_restart_ = sorted_tasks_.size(); sorted_tasks_.push_back(e); DCHECK(std::is_sorted(sorted_tasks_.begin(), sorted_tasks_.end())); } @@ -259,6 +261,63 @@ int DisjunctiveWithTwoItems::RegisterWith(GenericLiteralWatcher* watcher) { return id; } +bool DisjunctiveOverloadChecker::Propagate() { + helper_->SetTimeDirection(time_direction_); + + // Split problem into independent part. + // + // Many propagators in this file use the same approach, we start by processing + // the task by increasing start-min, packing everything to the left. We then + // process each "independent" set of task separately. A task is independent + // from the one before it, if its start-min wasn't pushed. + // + // This way, we get one or more window [window_start, window_end] so that for + // all task in the window, [start_min, end_min] is inside the window, and the + // end min of any set of task to the left is <= window_start, and the + // start_min of any task to the right is >= end_min. + window_.clear(); + IntegerValue window_end = kMinIntegerValue; + IntegerValue relevant_end; + int relevant_size = 0; + for (const TaskTime task_time : helper_->TaskByIncreasingShiftedStartMin()) { + const int task = task_time.task_index; + if (helper_->IsAbsent(task)) continue; + + const IntegerValue start_min = task_time.time; + if (start_min < window_end) { + window_.push_back(task_time); + window_end += helper_->DurationMin(task); + if (window_end > helper_->EndMax(task)) { + relevant_size = window_.size(); + relevant_end = window_end; + } + continue; + } + + // Process current window. + // We don't need to process the end of the window (after relevant_size) + // because these interval can be greedily assembled in a feasible solution. + window_.resize(relevant_size); + if (relevant_size > 0 && !PropagateSubwindow(relevant_end)) { + return false; + } + + // Start of the next window. + window_.clear(); + window_.push_back(task_time); + window_end = start_min + helper_->DurationMin(task); + relevant_size = 0; + } + + // Process last window. + window_.resize(relevant_size); + if (relevant_size > 0 && !PropagateSubwindow(relevant_end)) { + return false; + } + + return true; +} + // TODO(user): Improve the Overload Checker using delayed insertion. // We insert events at the cost of O(log n) per insertion, and this is where // the algorithm spends most of its time, thus it is worth improving. @@ -266,36 +325,29 @@ int DisjunctiveWithTwoItems::RegisterWith(GenericLiteralWatcher* watcher) { // set. This is useless for the overload checker as is since we need to check // overload after every insertion, but we could use an upper bound of the // theta envelope to save us from checking the actual value. -bool DisjunctiveOverloadChecker::Propagate() { - helper_->SetTimeDirection(time_direction_); - - // Set up theta tree. - event_to_task_.clear(); - event_time_.clear(); - int num_events = 0; - for (const auto task_time : helper_->TaskByIncreasingShiftedStartMin()) { - const int task = task_time.task_index; - - // TODO(user): We need to take into account task with zero duration because - // in this constraint, such a task cannot be overlapped by other. However, - // we currently use the fact that the energy min is zero to detect that a - // task is present and non-optional in the theta_tree_. Fix. - if (helper_->IsAbsent(task) || helper_->DurationMin(task) == 0) { - task_to_event_[task] = -1; - continue; +bool DisjunctiveOverloadChecker::PropagateSubwindow( + IntegerValue global_window_end) { + // Set up theta tree and task_by_increasing_end_max_. + const int window_size = window_.size(); + theta_tree_.Reset(window_size); + task_by_increasing_end_max_.clear(); + for (int i = 0; i < window_size; ++i) { + // No point adding a task if its end_max is too large. + const int task = window_[i].task_index; + const IntegerValue end_max = helper_->EndMax(task); + if (end_max < global_window_end) { + task_to_event_[task] = i; + task_by_increasing_end_max_.push_back({task, end_max}); } - event_to_task_.push_back(task); - event_time_.push_back(task_time.time); - task_to_event_[task] = num_events; - num_events++; } - theta_tree_.Reset(num_events); - // Introduce events by nondecreasing end_max, check for overloads. - for (const auto task_time : - ::gtl::reversed_view(helper_->TaskByDecreasingEndMax())) { + // Introduce events by increasing end_max, check for overloads. + std::sort(task_by_increasing_end_max_.begin(), + task_by_increasing_end_max_.end()); + for (const auto task_time : task_by_increasing_end_max_) { const int current_task = task_time.task_index; - if (task_to_event_[current_task] == -1) continue; + DCHECK_NE(task_to_event_[current_task], -1); + DCHECK(!helper_->IsAbsent(current_task)); { const int current_event = task_to_event_[current_task]; @@ -304,11 +356,11 @@ bool DisjunctiveOverloadChecker::Propagate() { // TODO(user,user): Add max energy deduction for variable // durations by putting the energy_max here and modifying the code // dealing with the optional envelope greater than current_end below. - theta_tree_.AddOrUpdateEvent(current_event, event_time_[current_event], + theta_tree_.AddOrUpdateEvent(current_event, window_[current_event].time, energy_min, energy_min); } else { theta_tree_.AddOrUpdateOptionalEvent( - current_event, event_time_[current_event], energy_min); + current_event, window_[current_event].time, energy_min); } } @@ -318,13 +370,13 @@ bool DisjunctiveOverloadChecker::Propagate() { helper_->ClearReason(); const int critical_event = theta_tree_.GetMaxEventWithEnvelopeGreaterThan(current_end); - const IntegerValue window_start = event_time_[critical_event]; + const IntegerValue window_start = window_[critical_event].time; const IntegerValue window_end = theta_tree_.GetEnvelopeOf(critical_event) - 1; - for (int event = critical_event; event < num_events; event++) { + for (int event = critical_event; event < window_size; event++) { const IntegerValue energy_min = theta_tree_.EnergyMin(event); if (energy_min > 0) { - const int task = event_to_task_[event]; + const int task = window_[event].task_index; helper_->AddPresenceReason(task); helper_->AddEnergyAfterReason(task, energy_min, window_start); helper_->AddEndMaxReason(task, window_end); @@ -345,16 +397,16 @@ bool DisjunctiveOverloadChecker::Propagate() { theta_tree_.GetEventsWithOptionalEnvelopeGreaterThan( current_end, &critical_event, &optional_event, &available_energy); - const int optional_task = event_to_task_[optional_event]; + const int optional_task = window_[optional_event].task_index; const IntegerValue optional_duration_min = helper_->DurationMin(optional_task); - const IntegerValue window_start = event_time_[critical_event]; + const IntegerValue window_start = window_[critical_event].time; const IntegerValue window_end = current_end + optional_duration_min - available_energy - 1; - for (int event = critical_event; event < num_events; event++) { + for (int event = critical_event; event < window_size; event++) { const IntegerValue energy_min = theta_tree_.EnergyMin(event); if (energy_min > 0) { - const int task = event_to_task_[event]; + const int task = window_[event].task_index; helper_->AddPresenceReason(task); helper_->AddEnergyAfterReason(task, energy_min, window_start); helper_->AddEndMaxReason(task, window_end); @@ -373,46 +425,113 @@ bool DisjunctiveOverloadChecker::Propagate() { theta_tree_.RemoveEvent(optional_event); } } + return true; } int DisjunctiveOverloadChecker::RegisterWith(GenericLiteralWatcher* watcher) { // This propagator reach the fix point in one pass. const int id = watcher->Register(this); - helper_->WatchAllTasks(id, watcher); + helper_->SetTimeDirection(time_direction_); + helper_->WatchAllTasks(id, watcher, /*watch_start_max=*/false, + /*watch_end_max=*/true); return id; } bool DisjunctiveDetectablePrecedences::Propagate() { helper_->SetTimeDirection(time_direction_); - const auto& task_by_increasing_end_min = helper_->TaskByIncreasingEndMin(); - const auto& task_by_decreasing_start_max = - helper_->TaskByDecreasingStartMax(); - const int num_tasks = helper_->NumTasks(); - int queue_index = num_tasks - 1; + // Split problem into independent part. + // + // The "independent" window can be processed separately because for each of + // them, a task [start-min, end-min] is in the window [window_start, + // window_end]. So any task to the left of the window cannot push such + // task start_min, and any task to the right of the window will have a + // start_max >= end_min, so wouldn't be in detectable precedence. + window_.clear(); + IntegerValue window_end = kMinIntegerValue; + IntegerValue window_max_of_end_min = kMinIntegerValue; + for (const TaskTime task_time : helper_->TaskByIncreasingShiftedStartMin()) { + const int task = task_time.task_index; + if (helper_->IsAbsent(task)) continue; + + const IntegerValue start_min = task_time.time; + if (start_min < window_end) { + const IntegerValue duration_min = helper_->DurationMin(task); + const IntegerValue end_min = helper_->EndMin(task); + window_.push_back({task, end_min}); + window_end += duration_min; + window_max_of_end_min = std::max(window_max_of_end_min, end_min); + continue; + } + + // Process current window. + if (window_.size() > 1 && !PropagateSubwindow(window_max_of_end_min)) { + return false; + } + + // Start of the next window. + window_.clear(); + const IntegerValue duration_min = helper_->DurationMin(task); + const IntegerValue end_min = helper_->EndMin(task); + window_.push_back({task, end_min}); + window_end = start_min + duration_min; + window_max_of_end_min = end_min; + } + + if (window_.size() > 1 && !PropagateSubwindow(window_max_of_end_min)) { + return false; + } + + return true; +} + +bool DisjunctiveDetectablePrecedences::PropagateSubwindow( + IntegerValue max_end_min) { + task_by_increasing_start_max_.clear(); + for (const TaskTime entry : window_) { + const int task = entry.task_index; + const IntegerValue start_max = helper_->StartMax(task); + if (start_max < max_end_min && helper_->IsPresent(task)) { + task_by_increasing_start_max_.push_back({task, start_max}); + } + } + if (task_by_increasing_start_max_.empty()) return true; + std::sort(task_by_increasing_start_max_.begin(), + task_by_increasing_start_max_.end()); + + // The window is already sorted by shifted_start_min, so there is likely a + // good correlation, hence the incremental sort. + // + // TODO(user): Instead of end-min, we should use end-min if present. Same in + // other places. + auto& task_by_increasing_end_min = window_; + IncrementalSort(task_by_increasing_end_min.begin(), + task_by_increasing_end_min.end()); + task_set_.Clear(); + int queue_index = 0; + const int queue_size = task_by_increasing_start_max_.size(); for (const auto task_time : task_by_increasing_end_min) { const int t = task_time.task_index; const IntegerValue end_min = task_time.time; + DCHECK(!helper_->IsAbsent(t)); - if (helper_->IsAbsent(t)) continue; - - while (queue_index >= 0) { - const auto to_insert = task_by_decreasing_start_max[queue_index]; - const int task_index = to_insert.task_index; + while (queue_index < queue_size) { + const auto to_insert = task_by_increasing_start_max_[queue_index]; const IntegerValue start_max = to_insert.time; if (end_min <= start_max) break; - if (helper_->IsPresent(task_index)) { - task_set_.AddEntry({task_index, helper_->ShiftedStartMin(task_index), - helper_->DurationMin(task_index)}); - } - --queue_index; + + const int task_index = to_insert.task_index; + DCHECK(helper_->IsPresent(task_index)); + task_set_.AddEntry({task_index, helper_->ShiftedStartMin(task_index), + helper_->DurationMin(task_index)}); + ++queue_index; } - // task_set_ contains all the tasks that must be executed before t. - // They are in "dectable precedence" because their start_max is smaller than - // the end-min of t like so: + // task_set_ contains all the tasks that must be executed before t. They are + // in "detectable precedence" because their start_max is smaller than the + // end-min of t like so: // [(the task t) // (a task in task_set_)] // From there, we deduce that the start-min of t is greater or equal to the @@ -428,7 +547,7 @@ bool DisjunctiveDetectablePrecedences::Propagate() { // We need: // - StartMax(ct) < EndMin(t) for the detectable precedence. - // - StartMin(ct) > window_start for the end_min_of_critical_tasks reason. + // - StartMin(ct) >= window_start for end_min_of_critical_tasks. const IntegerValue window_start = sorted_tasks[critical_index].start_min; for (int i = critical_index; i < sorted_tasks.size(); ++i) { const int ct = sorted_tasks[i].task; @@ -463,7 +582,9 @@ bool DisjunctiveDetectablePrecedences::Propagate() { int DisjunctiveDetectablePrecedences::RegisterWith( GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - helper_->WatchAllTasks(id, watcher); + helper_->SetTimeDirection(time_direction_); + helper_->WatchAllTasks(id, watcher, /*watch_start_max=*/true, + /*watch_end_max=*/false); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); return id; } @@ -471,15 +592,19 @@ int DisjunctiveDetectablePrecedences::RegisterWith( bool DisjunctivePrecedences::Propagate() { helper_->SetTimeDirection(time_direction_); - const int num_tasks = helper_->NumTasks(); - for (int t = 0; t < num_tasks; ++t) { - task_is_currently_present_[t] = helper_->IsPresent(t); - } - precedences_->ComputePrecedences(helper_->EndVars(), - task_is_currently_present_, &before_); + index_to_end_vars_.clear(); + index_to_task_.clear(); + index_to_cached_shifted_start_min_.clear(); + for (const auto task_time : helper_->TaskByIncreasingShiftedStartMin()) { + const int task = task_time.task_index; + if (!helper_->IsPresent(task)) continue; + + index_to_task_.push_back(task); + index_to_end_vars_.push_back(helper_->EndVars()[task]); + index_to_cached_shifted_start_min_.push_back(task_time.time); + } + precedences_->ComputePrecedences(index_to_end_vars_, &before_); - // We don't care about the initial content of this vector. - task_to_arc_index_.resize(num_tasks); int critical_index; const int size = before_.size(); for (int i = 0; i < size;) { @@ -489,14 +614,17 @@ bool DisjunctivePrecedences::Propagate() { const int initial_i = i; IntegerValue min_offset = before_[i].offset; for (; i < size && before_[i].var == var; ++i) { - const int task = before_[i].index; + const int task = index_to_task_[before_[i].index]; min_offset = std::min(min_offset, before_[i].offset); + + // The task are actually in sorted order, so we do not need to call + // task_set_.Sort(). This property is DCHECKed. task_set_.AddUnsortedEntry( - {task, helper_->ShiftedStartMin(task), helper_->DurationMin(task)}); + {task, index_to_cached_shifted_start_min_[before_[i].index], + helper_->DurationMin(task)}); } DCHECK_GE(task_set_.SortedTasks().size(), 2); if (integer_trail_->IsCurrentlyIgnored(var)) continue; - task_set_.Sort(); // TODO(user): Only use the min_offset of the critical task? Or maybe do a // more general computation to find by how much we can push var? @@ -508,8 +636,10 @@ bool DisjunctivePrecedences::Propagate() { helper_->ClearReason(); // Fill task_to_arc_index_ since we need it for the reason. + // Note that we do not care about the initial content of this vector. for (int j = initial_i; j < i; ++j) { - task_to_arc_index_[before_[j].index] = before_[j].arc_index; + const int task = index_to_task_[before_[j].index]; + task_to_arc_index_[task] = before_[j].arc_index; } const IntegerValue window_start = sorted_tasks[critical_index].start_min; @@ -538,40 +668,126 @@ bool DisjunctivePrecedences::Propagate() { int DisjunctivePrecedences::RegisterWith(GenericLiteralWatcher* watcher) { // This propagator reach the fixed point in one go. const int id = watcher->Register(this); - helper_->WatchAllTasks(id, watcher); + helper_->SetTimeDirection(time_direction_); + helper_->WatchAllTasks(id, watcher, /*watch_start_max=*/false, + /*watch_end_max=*/false); return id; } bool DisjunctiveNotLast::Propagate() { helper_->SetTimeDirection(time_direction_); + const auto& task_by_decreasing_start_max = helper_->TaskByDecreasingStartMax(); + const auto& task_by_increasing_shifted_start_min = + helper_->TaskByIncreasingShiftedStartMin(); - const int num_tasks = helper_->NumTasks(); - int queue_index = num_tasks - 1; + // Split problem into independent part. + // + // The situation is trickier here, and we use two windows: + // - The classical "start_min_window_" as in the other propagator. + // - A second window, that includes all the task with a start_max inside + // [window_start, window_end]. + // + // Now, a task from the second window can be detected to be "not last" by only + // looking at the task in the first window. Tasks to the left do not cause + // issue for the task to be last, and tasks to the right will not lower the + // end-min of the task under consideration. + int queue_index = task_by_decreasing_start_max.size() - 1; + const int num_tasks = task_by_increasing_shifted_start_min.size(); + for (int i = 0; i < num_tasks;) { + start_min_window_.clear(); + IntegerValue window_end = kMinIntegerValue; + for (; i < num_tasks; ++i) { + const TaskTime task_time = task_by_increasing_shifted_start_min[i]; + const int task = task_time.task_index; + if (!helper_->IsPresent(task)) continue; + + const IntegerValue start_min = task_time.time; + if (start_min_window_.empty()) { + start_min_window_.push_back(task_time); + window_end = start_min + helper_->DurationMin(task); + } else if (start_min < window_end) { + start_min_window_.push_back(task_time); + window_end += helper_->DurationMin(task); + } else { + break; + } + } + + // Add to start_max_window_ all the task whose start_max + // fall into [window_start, window_end). + start_max_window_.clear(); + for (; queue_index >= 0; queue_index--) { + const auto task_time = task_by_decreasing_start_max[queue_index]; + + // Note that we add task whose presence is still unknown here. + if (task_time.time >= window_end) break; + if (helper_->IsAbsent(task_time.task_index)) continue; + start_max_window_.push_back(task_time); + } + + // If this is the case, we cannot propagate more than the detectable + // precedence propagator. Note that this continue must happen after we + // computed start_max_window_ though. + if (start_min_window_.size() <= 1) continue; + + // Process current window. + if (!start_max_window_.empty() && !PropagateSubwindow()) { + return false; + } + } + return true; +} + +bool DisjunctiveNotLast::PropagateSubwindow() { + auto& task_by_increasing_end_max = start_max_window_; + for (TaskTime& entry : task_by_increasing_end_max) { + entry.time = helper_->EndMax(entry.task_index); + } + IncrementalSort(task_by_increasing_end_max.begin(), + task_by_increasing_end_max.end()); + + const IntegerValue threshold = task_by_increasing_end_max.back().time; + auto& task_by_increasing_start_max = start_min_window_; + int queue_size = 0; + for (const TaskTime entry : task_by_increasing_start_max) { + const int task = entry.task_index; + const IntegerValue start_max = helper_->StartMax(task); + DCHECK(helper_->IsPresent(task)); + if (start_max < threshold) { + task_by_increasing_start_max[queue_size++] = {task, start_max}; + } + } + + // If the size is one, we cannot propagate more than the detectable precedence + // propagator. + if (queue_size <= 1) return true; + + task_by_increasing_start_max.resize(queue_size); + std::sort(task_by_increasing_start_max.begin(), + task_by_increasing_start_max.end()); task_set_.Clear(); - const auto& task_by_increasing_end_max = - ::gtl::reversed_view(helper_->TaskByDecreasingEndMax()); + int queue_index = 0; for (const auto task_time : task_by_increasing_end_max) { const int t = task_time.task_index; const IntegerValue end_max = task_time.time; - - if (helper_->IsAbsent(t)) continue; + DCHECK(!helper_->IsAbsent(t)); // task_set_ contains all the tasks that must start before the end-max of t. // These are the only candidates that have a chance to decrease the end-max // of t. - while (queue_index >= 0) { - const auto to_insert = task_by_decreasing_start_max[queue_index]; - const int task_index = to_insert.task_index; + while (queue_index < queue_size) { + const auto to_insert = task_by_increasing_start_max[queue_index]; const IntegerValue start_max = to_insert.time; if (end_max <= start_max) break; - if (helper_->IsPresent(task_index)) { - task_set_.AddEntry({task_index, helper_->ShiftedStartMin(task_index), - helper_->DurationMin(task_index)}); - } - --queue_index; + + const int task_index = to_insert.task_index; + DCHECK(helper_->IsPresent(task_index)); + task_set_.AddEntry({task_index, helper_->ShiftedStartMin(task_index), + helper_->DurationMin(task_index)}); + ++queue_index; } // In the following case, task t cannot be after all the critical tasks @@ -814,7 +1030,9 @@ bool DisjunctiveEdgeFinding::Propagate() { int DisjunctiveEdgeFinding::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); - helper_->WatchAllTasks(id, watcher); + helper_->SetTimeDirection(time_direction_); + helper_->WatchAllTasks(id, watcher, /*watch_start_max=*/false, + /*watch_end_max=*/true); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); return id; } diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index fb585830cf..36ab91ee10 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -72,9 +72,12 @@ class TaskSet { optimized_restart_ = 0; } void AddEntry(const Entry& e); - void NotifyEntryIsNowLastIfPresent(const Entry& e); void RemoveEntryWithIndex(int index); + // Advanced usage, if the entry is present, this assumes that its start_min is + // >= the end min without it, and update the datastructure accordingly. + void NotifyEntryIsNowLastIfPresent(const Entry& e); + // Advanced usage. Instead of calling many AddEntry(), it is more efficient to // call AddUnsortedEntry() instead, but then Sort() MUST be called just after // the insertions. Nothing is checked here, so it is up to the client to do @@ -132,13 +135,16 @@ class DisjunctiveOverloadChecker : public PropagatorInterface { int RegisterWith(GenericLiteralWatcher* watcher); private: + bool PropagateSubwindow(IntegerValue global_window_end); + const bool time_direction_; SchedulingConstraintHelper* helper_; + std::vector window_; + std::vector task_by_increasing_end_max_; + ThetaLambdaTree theta_tree_; std::vector task_to_event_; - std::vector event_to_task_; - std::vector event_time_; }; class DisjunctiveDetectablePrecedences : public PropagatorInterface { @@ -152,6 +158,11 @@ class DisjunctiveDetectablePrecedences : public PropagatorInterface { int RegisterWith(GenericLiteralWatcher* watcher); private: + bool PropagateSubwindow(IntegerValue max_end_min); + + std::vector window_; + std::vector task_by_increasing_start_max_; + const bool time_direction_; SchedulingConstraintHelper* helper_; TaskSet task_set_; @@ -167,6 +178,11 @@ class DisjunctiveNotLast : public PropagatorInterface { int RegisterWith(GenericLiteralWatcher* watcher); private: + bool PropagateSubwindow(); + + std::vector start_min_window_; + std::vector start_max_window_; + const bool time_direction_; SchedulingConstraintHelper* helper_; TaskSet task_set_; @@ -207,7 +223,7 @@ class DisjunctivePrecedences : public PropagatorInterface { integer_trail_(integer_trail), precedences_(precedences), task_set_(helper->NumTasks()), - task_is_currently_present_(helper->NumTasks()) {} + task_to_arc_index_(helper->NumTasks()) {} bool Propagate() final; int RegisterWith(GenericLiteralWatcher* watcher); @@ -217,8 +233,11 @@ class DisjunctivePrecedences : public PropagatorInterface { IntegerTrail* integer_trail_; PrecedencesPropagator* precedences_; + std::vector index_to_end_vars_; + std::vector index_to_task_; + std::vector index_to_cached_shifted_start_min_; + TaskSet task_set_; - std::vector task_is_currently_present_; std::vector task_to_arc_index_; std::vector before_; }; diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index 78ca61cea4..f4364add94 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -51,12 +51,23 @@ IntervalVariable IntervalsRepository::CreateInterval(IntegerVariable start, SchedulingConstraintHelper::SchedulingConstraintHelper( const std::vector& tasks, Model* model) - : trail_(model->GetOrCreate()), + : repository_(model->GetOrCreate()), + trail_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), precedences_(model->GetOrCreate()), - current_time_direction_(true), - visible_intervals_(tasks.size(), true) { - auto* repository = model->GetOrCreate(); + current_time_direction_(true) { + Init(tasks); +} + +SchedulingConstraintHelper::SchedulingConstraintHelper(Model* model) + : repository_(model->GetOrCreate()), + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + precedences_(model->GetOrCreate()), + current_time_direction_(true) {} + +void SchedulingConstraintHelper::Init( + const std::vector& tasks) { start_vars_.clear(); end_vars_.clear(); minus_end_vars_.clear(); @@ -65,22 +76,22 @@ SchedulingConstraintHelper::SchedulingConstraintHelper( fixed_durations_.clear(); reason_for_presence_.clear(); for (const IntervalVariable i : tasks) { - if (repository->IsOptional(i)) { - reason_for_presence_.push_back(repository->IsPresentLiteral(i).Index()); + if (repository_->IsOptional(i)) { + reason_for_presence_.push_back(repository_->IsPresentLiteral(i).Index()); } else { reason_for_presence_.push_back(kNoLiteralIndex); } - if (repository->SizeVar(i) == kNoIntegerVariable) { + if (repository_->SizeVar(i) == kNoIntegerVariable) { duration_vars_.push_back(kNoIntegerVariable); - fixed_durations_.push_back(repository->MinSize(i)); + fixed_durations_.push_back(repository_->MinSize(i)); } else { - duration_vars_.push_back(repository->SizeVar(i)); + duration_vars_.push_back(repository_->SizeVar(i)); fixed_durations_.push_back(IntegerValue(0)); } - start_vars_.push_back(repository->StartVar(i)); - end_vars_.push_back(repository->EndVar(i)); - minus_start_vars_.push_back(NegationOf(repository->StartVar(i))); - minus_end_vars_.push_back(NegationOf(repository->EndVar(i))); + start_vars_.push_back(repository_->StartVar(i)); + end_vars_.push_back(repository_->EndVar(i)); + minus_start_vars_.push_back(NegationOf(repository_->StartVar(i))); + minus_end_vars_.push_back(NegationOf(repository_->EndVar(i))); } const int num_tasks = start_vars_.size(); @@ -112,7 +123,7 @@ void SchedulingConstraintHelper::SetTimeDirection(bool is_forward) { task_by_decreasing_shifted_end_max_); } -const std::vector& +const std::vector& SchedulingConstraintHelper::TaskByIncreasingStartMin() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -124,7 +135,7 @@ SchedulingConstraintHelper::TaskByIncreasingStartMin() { return task_by_increasing_min_start_; } -const std::vector& +const std::vector& SchedulingConstraintHelper::TaskByIncreasingEndMin() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -136,7 +147,7 @@ SchedulingConstraintHelper::TaskByIncreasingEndMin() { return task_by_increasing_min_end_; } -const std::vector& +const std::vector& SchedulingConstraintHelper::TaskByDecreasingStartMax() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -149,7 +160,7 @@ SchedulingConstraintHelper::TaskByDecreasingStartMax() { return task_by_decreasing_max_start_; } -const std::vector& +const std::vector& SchedulingConstraintHelper::TaskByDecreasingEndMax() { const int num_tasks = NumTasks(); for (int i = 0; i < num_tasks; ++i) { @@ -161,13 +172,18 @@ SchedulingConstraintHelper::TaskByDecreasingEndMax() { return task_by_decreasing_max_end_; } -const std::vector& +const std::vector& SchedulingConstraintHelper::TaskByIncreasingShiftedStartMin() { const int num_tasks = NumTasks(); + bool is_sorted = true; + IntegerValue previous = kMinIntegerValue; for (int i = 0; i < num_tasks; ++i) { TaskTime& ref = task_by_increasing_shifted_start_min_[i]; ref.time = ShiftedStartMin(ref.task_index); + is_sorted = is_sorted && ref.time >= previous; + previous = ref.time; } + if (is_sorted) return task_by_increasing_shifted_start_min_; IncrementalSort(task_by_increasing_shifted_start_min_.begin(), task_by_increasing_shifted_start_min_.end()); return task_by_increasing_shifted_start_min_; @@ -269,12 +285,20 @@ bool SchedulingConstraintHelper::ReportConflict() { return integer_trail_->ReportConflict(literal_reason_, integer_reason_); } -void SchedulingConstraintHelper::WatchAllTasks( - int id, GenericLiteralWatcher* watcher) const { +void SchedulingConstraintHelper::WatchAllTasks(int id, + GenericLiteralWatcher* watcher, + bool watch_start_max, + bool watch_end_max) const { const int num_tasks = start_vars_.size(); for (int t = 0; t < num_tasks; ++t) { - watcher->WatchIntegerVariable(start_vars_[t], id); - watcher->WatchIntegerVariable(end_vars_[t], id); + watcher->WatchLowerBound(start_vars_[t], id); + watcher->WatchLowerBound(end_vars_[t], id); + if (watch_start_max) { + watcher->WatchUpperBound(start_vars_[t], id); + } + if (watch_end_max) { + watcher->WatchUpperBound(end_vars_[t], id); + } if (duration_vars_[t] != kNoIntegerVariable) { watcher->WatchLowerBound(duration_vars_[t], id); } @@ -305,17 +329,5 @@ void SchedulingConstraintHelper::ImportOtherReasons( other_helper.integer_reason_.end()); } -void SchedulingConstraintHelper::SetAllIntervalsVisible() { - visible_intervals_.assign(NumTasks(), true); -} - -void SchedulingConstraintHelper::SetVisibleIntervals( - const std::vector& visible_intervals) { - visible_intervals_.assign(NumTasks(), false); - for (const int t : visible_intervals) { - visible_intervals_[t] = true; - } -} - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index 6f5c86e000..4f99ec6d3e 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -109,6 +109,16 @@ class IntervalsRepository { DISALLOW_COPY_AND_ASSIGN(IntervalsRepository); }; +// An helper struct to sort task by time. This is used by the +// SchedulingConstraintHelper but also by many scheduling propagators to sort +// tasks. +struct TaskTime { + int task_index; + IntegerValue time; + bool operator<(TaskTime other) const { return time < other.time; } + bool operator>(TaskTime other) const { return time > other.time; } +}; + // Helper class shared by the propagators that manage a given list of tasks. // // One of the main advantage of this class is that it allows to share the @@ -116,11 +126,17 @@ class IntervalsRepository { // code. class SchedulingConstraintHelper { public: + SchedulingConstraintHelper(Model* model); + // All the functions below refer to a task by its index t in the tasks // vector given at construction. SchedulingConstraintHelper(const std::vector& tasks, Model* model); + // Resets the class to the same state as if it was constructed with + // the given set of tasks. + void Init(const std::vector& tasks); + // Returns the number of task. int NumTasks() const { return start_vars_.size(); } @@ -174,12 +190,6 @@ class SchedulingConstraintHelper { // // TODO(user): we could merge the first loop of IncrementalSort() with the // loop that fill TaskTime.time at each call. - struct TaskTime { - int task_index; - IntegerValue time; - bool operator<(TaskTime other) const { return time < other.time; } - bool operator>(TaskTime other) const { return time > other.time; } - }; const std::vector& TaskByIncreasingStartMin(); const std::vector& TaskByIncreasingEndMin(); const std::vector& TaskByDecreasingStartMax(); @@ -229,7 +239,9 @@ class SchedulingConstraintHelper { // Registers the given propagator id to be called if any of the tasks // in this class change. - void WatchAllTasks(int id, GenericLiteralWatcher* watcher) const; + void WatchAllTasks(int id, GenericLiteralWatcher* watcher, + bool watch_start_max = true, + bool watch_end_max = true) const; // Manages the other helper (used by the diffn constraint). // @@ -251,11 +263,6 @@ class SchedulingConstraintHelper { // This is used in the 2D energetic reasoning in the diffn constraint. void ImportOtherReasons(const SchedulingConstraintHelper& other_helper); - // Manages the visibility of intervals. When marked as invisible, IsPresent() - // will always return false, and IsAbsent() will always return true. - void SetAllIntervalsVisible(); - void SetVisibleIntervals(const std::vector& visible_intervals); - private: // Internal function for IncreaseStartMin()/DecreaseEndMax(). bool PushIntervalBound(int t, IntegerLiteral lit); @@ -269,10 +276,7 @@ class SchedulingConstraintHelper { // Import the reasons on the other helper into this helper. void ImportOtherReasons(); - // Returns true if the interval is visible. Note that this method always - // return true if SetVisibleIntervals() has never been called. - bool IsVisible(int t) const { return visible_intervals_[t]; } - + IntervalsRepository* repository_; Trail* trail_; IntegerTrail* integer_trail_; PrecedencesPropagator* precedences_; @@ -310,9 +314,6 @@ class SchedulingConstraintHelper { SchedulingConstraintHelper* other_helper_ = nullptr; IntegerValue event_for_other_helper_; std::vector already_added_to_other_reasons_; - - // Extra filter on the helper. Only non ignored intervals are even looked at. - std::vector visible_intervals_; }; // ============================================================================= @@ -361,18 +362,15 @@ inline bool SchedulingConstraintHelper::EndIsFixed(int t) const { } inline bool SchedulingConstraintHelper::IsOptional(int t) const { - if (!IsVisible(t)) return false; return reason_for_presence_[t] != kNoLiteralIndex; } inline bool SchedulingConstraintHelper::IsPresent(int t) const { - if (!IsVisible(t)) return false; if (reason_for_presence_[t] == kNoLiteralIndex) return true; return trail_->Assignment().LiteralIsTrue(Literal(reason_for_presence_[t])); } inline bool SchedulingConstraintHelper::IsAbsent(int t) const { - if (!IsVisible(t)) return true; if (reason_for_presence_[t] == kNoLiteralIndex) return false; return trail_->Assignment().LiteralIsFalse(Literal(reason_for_presence_[t])); } @@ -387,7 +385,6 @@ inline void SchedulingConstraintHelper::ClearReason() { } inline void SchedulingConstraintHelper::AddPresenceReason(int t) { - DCHECK(IsVisible(t)); AddOtherReason(t); if (reason_for_presence_[t] != kNoLiteralIndex) { literal_reason_.push_back(Literal(reason_for_presence_[t]).Negated()); @@ -395,7 +392,6 @@ inline void SchedulingConstraintHelper::AddPresenceReason(int t) { } inline void SchedulingConstraintHelper::AddDurationMinReason(int t) { - DCHECK(IsVisible(t)); AddOtherReason(t); if (duration_vars_[t] != kNoIntegerVariable) { integer_reason_.push_back( @@ -405,7 +401,6 @@ inline void SchedulingConstraintHelper::AddDurationMinReason(int t) { inline void SchedulingConstraintHelper::AddDurationMinReason( int t, IntegerValue lower_bound) { - DCHECK(IsVisible(t)); AddOtherReason(t); if (duration_vars_[t] != kNoIntegerVariable) { DCHECK_GE(DurationMin(t), lower_bound); @@ -416,7 +411,6 @@ inline void SchedulingConstraintHelper::AddDurationMinReason( inline void SchedulingConstraintHelper::AddStartMinReason( int t, IntegerValue lower_bound) { - DCHECK(IsVisible(t)); DCHECK_GE(StartMin(t), lower_bound); AddOtherReason(t); integer_reason_.push_back( @@ -425,7 +419,6 @@ inline void SchedulingConstraintHelper::AddStartMinReason( inline void SchedulingConstraintHelper::AddStartMaxReason( int t, IntegerValue upper_bound) { - DCHECK(IsVisible(t)); DCHECK_LE(StartMax(t), upper_bound); AddOtherReason(t); integer_reason_.push_back( @@ -434,7 +427,6 @@ inline void SchedulingConstraintHelper::AddStartMaxReason( inline void SchedulingConstraintHelper::AddEndMinReason( int t, IntegerValue lower_bound) { - DCHECK(IsVisible(t)); DCHECK_GE(EndMin(t), lower_bound); AddOtherReason(t); integer_reason_.push_back( @@ -443,7 +435,6 @@ inline void SchedulingConstraintHelper::AddEndMinReason( inline void SchedulingConstraintHelper::AddEndMaxReason( int t, IntegerValue upper_bound) { - DCHECK(IsVisible(t)); DCHECK_LE(EndMax(t), upper_bound); AddOtherReason(t); integer_reason_.push_back( @@ -452,7 +443,6 @@ inline void SchedulingConstraintHelper::AddEndMaxReason( inline void SchedulingConstraintHelper::AddEnergyAfterReason( int t, IntegerValue energy_min, IntegerValue time) { - DCHECK(IsVisible(t)); if (StartMin(t) >= time) { AddStartMinReason(t, time); } else { diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index bc020c58a5..5b51f6deb3 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -125,14 +125,13 @@ void PrecedencesPropagator::Untrail(const Trail& trail, int trail_index) { // permutation. void PrecedencesPropagator::ComputePrecedences( const std::vector& vars, - const std::vector& to_consider, std::vector* output) { tmp_sorted_vars_.clear(); tmp_precedences_.clear(); for (int index = 0; index < vars.size(); ++index) { const IntegerVariable var = vars[index]; CHECK_NE(kNoIntegerVariable, var); - if (!to_consider[index] || var >= impacted_arcs_.size()) continue; + if (var >= impacted_arcs_.size()) continue; for (const ArcIndex arc_index : impacted_arcs_[var]) { const ArcInfo& arc = arcs_[arc_index]; if (integer_trail_->IsCurrentlyIgnored(arc.head_var)) continue; diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index caffb8c75e..049d932ad3 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -101,14 +101,15 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { // Note that the IntegerVariable in the vector are also returned in // topological order for a more efficient propagation in // DisjunctivePrecedences::Propagate() where this is used. + // + // Important: For identical vars, the entry are sorted by index. struct IntegerPrecedences { int index; // position in vars. IntegerVariable var; // An IntegerVariable that is >= to vars[index]. int arc_index; // Used by AddPrecedenceReason(). - IntegerValue offset; // we have: input_vars[index] + offset <= var + IntegerValue offset; // we have: vars[index] + offset <= var }; void ComputePrecedences(const std::vector& vars, - const std::vector& to_consider, std::vector* output); void AddPrecedenceReason(int arc_index, IntegerValue min_offset, std::vector* literal_reason, diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 5a6ef62cea..9df461a1d9 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -1496,6 +1496,10 @@ class CpSolver(object): """Returns some statistics on the solution found as a string.""" return pywrapsat.SatHelper.SolverResponseStats(self.__solution) + def ResponseProto(self): + """Returns the response object.""" + return self.__solution + class CpSolverSolutionCallback(pywrapsat.SolutionCallback): """Solution callback. diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 3d42aea9e5..29512ca408 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -892,7 +892,7 @@ void SatSolver::ClearNewlyAddedBinaryClauses() { namespace { // Return the next value that is a multiple of interval. -int NextMultipleOf(int64 value, int64 interval) { +int64 NextMultipleOf(int64 value, int64 interval) { return interval * (1 + value / interval); } } // namespace @@ -1095,14 +1095,15 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit) { parameters_->minimize_with_propagation_restart_period(); // Variables used to show the search progress. - const int kDisplayFrequency = 10000; - int next_display = parameters_->log_search_progress() - ? NextMultipleOf(num_failures(), kDisplayFrequency) - : std::numeric_limits::max(); + const int64 kDisplayFrequency = 10000; + int64 next_display = parameters_->log_search_progress() + ? NextMultipleOf(num_failures(), kDisplayFrequency) + : std::numeric_limits::max(); // Variables used to check the memory limit every kMemoryCheckFrequency. - const int kMemoryCheckFrequency = 10000; - int next_memory_check = NextMultipleOf(num_failures(), kMemoryCheckFrequency); + const int64 kMemoryCheckFrequency = 10000; + int64 next_memory_check = + NextMultipleOf(num_failures(), kMemoryCheckFrequency); // The max_number_of_conflicts is per solve but the counter is for the whole // solver. diff --git a/ortools/sat/timetable.h b/ortools/sat/timetable.h index 7f0f891813..8af9930759 100644 --- a/ortools/sat/timetable.h +++ b/ortools/sat/timetable.h @@ -125,8 +125,8 @@ class TimeTablingPerTask : public PropagatorInterface { RevRepository rev_repository_integer_value_; // Vector of tasks sorted by maximum starting (resp. minimum ending) time. - std::vector by_start_max_; - std::vector by_end_min_; + std::vector by_start_max_; + std::vector by_end_min_; // Tasks contained in the range [left_start_, right_start_) of by_start_max_ // must be sorted and considered when building the profile. The state of these