From 198d6295e49b3017b799ec87eef5a74dc0048fd2 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Thu, 23 May 2019 18:55:20 +0200 Subject: [PATCH] optimize speed and memory --- ortools/sat/cp_model_presolve.cc | 58 ++++++++++---------- ortools/sat/linear_programming_constraint.cc | 16 ++---- ortools/sat/util.cc | 33 +++++++++++ ortools/sat/util.h | 27 +++++++++ ortools/util/sorted_interval_list.cc | 13 +++-- ortools/util/sorted_interval_list.h | 7 ++- 6 files changed, 109 insertions(+), 45 deletions(-) diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 3746aa13e3..69849af9df 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -131,12 +131,19 @@ bool PresolveContext::VariableIsUniqueAndRemovable(int ref) const { } Domain PresolveContext::DomainOf(int ref) const { - if (RefIsPositive(ref)) return domains[ref]; - return domains[PositiveRef(ref)].Negation(); + Domain result; + if (RefIsPositive(ref)) { + result = domains[ref]; + } else { + result = domains[PositiveRef(ref)].Negation(); + } + return result; } bool PresolveContext::DomainContains(int ref, int64 value) const { - if (!RefIsPositive(ref)) return DomainContains(NegatedRef(ref), -value); + if (!RefIsPositive(ref)) { + return domains[PositiveRef(ref)].Contains(-value); + } return domains[ref].Contains(value); } @@ -797,9 +804,9 @@ bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) { // infered_domain ∩ [kint64min, target_ub] ⊂ target_domain // then the constraint is really max(...) <= target_ub and we can simplify it. if (context->VariableIsUniqueAndRemovable(target_ref)) { - const Domain target_domain = context->DomainOf(target_ref); + const Domain& target_domain = context->DomainOf(target_ref); if (infered_domain.IntersectionWith(Domain(kint64min, target_domain.Max())) - .IsIncludedIn(context->DomainOf(target_ref))) { + .IsIncludedIn(target_domain)) { if (infered_domain.Max() <= target_domain.Max()) { // The constraint is always satisfiable. context->UpdateRuleStats("int_max: always true"); @@ -1783,21 +1790,19 @@ bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) { if (!ct->enforcement_literal().empty()) return false; bool changed = false; + const Domain start_domain = context->DomainOf(start); + const Domain end_domain = context->DomainOf(end); + const Domain size_domain = context->DomainOf(size); + if (!context->IntersectDomainWith(end, start_domain.AdditionWith(size_domain), + &changed)) { + return false; + } if (!context->IntersectDomainWith( - end, context->DomainOf(start).AdditionWith(context->DomainOf(size)), - &changed)) { + start, end_domain.AdditionWith(size_domain.Negation()), &changed)) { return false; } - if (!context->IntersectDomainWith(start, - context->DomainOf(end).AdditionWith( - context->DomainOf(size).Negation()), - &changed)) { - return false; - } - if (!context->IntersectDomainWith(size, - context->DomainOf(end).AdditionWith( - context->DomainOf(start).Negation()), - &changed)) { + if (!context->IntersectDomainWith( + size, end_domain.AdditionWith(start_domain.Negation()), &changed)) { return false; } if (changed) { @@ -1844,14 +1849,14 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { // Filter possible index values. Accumulate variable domains to build // a possible target domain. Domain infered_domain; - const Domain initial_index_domain = context->DomainOf(index_ref); - Domain target_domain = context->DomainOf(target_ref); + const Domain& initial_index_domain = context->DomainOf(index_ref); + const Domain& target_domain = context->DomainOf(target_ref); for (const ClosedInterval interval : initial_index_domain) { for (int value = interval.start; value <= interval.end; ++value) { CHECK_GE(value, 0); CHECK_LT(value, ct->element().vars_size()); const int ref = ct->element().vars(value); - const Domain domain = context->DomainOf(ref); + const Domain& domain = context->DomainOf(ref); if (domain.IntersectionWith(target_domain).IsEmpty()) { bool domain_modified = false; if (!context->IntersectDomainWith( @@ -2000,7 +2005,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { // Eventually, domain sizes will be synchronized. bool changed_values = false; std::vector valid_index_values; - const Domain index_domain = context->DomainOf(index_ref); + const Domain& index_domain = context->DomainOf(index_ref); for (const ClosedInterval interval : index_domain) { for (int i = interval.start; i <= interval.end; ++i) { const int64 value = context->MinOf(ct->element().vars(i)); @@ -2035,10 +2040,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } if (context->IsFixed(target_ref)) { - const Domain index_domain = context->DomainOf(index_ref); const int64 target_value = context->MinOf(target_ref); - - for (const ClosedInterval& interval : index_domain) { + for (const ClosedInterval& interval : context->DomainOf(index_ref)) { for (int64 v = interval.start; v <= interval.end; ++v) { const int var = ct->element().vars(v); const int index_lit = @@ -2054,8 +2057,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { if (target_ref == index_ref) { // Filter impossible index values. - Domain index_domain = context->DomainOf(index_ref); std::vector possible_indices; + const Domain& index_domain = context->DomainOf(index_ref); for (const ClosedInterval& interval : index_domain) { for (int64 value = interval.start; value <= interval.end; ++value) { const int ref = ct->element().vars(value); @@ -2077,7 +2080,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } if (all_constants) { - const Domain index_domain = context->DomainOf(index_ref); absl::flat_hash_map> supports; // Help linearization. @@ -2087,7 +2089,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { lin->add_coeffs(-1); int64 rhs = 0; - for (const ClosedInterval& interval : index_domain) { + for (const ClosedInterval& interval : context->DomainOf(index_ref)) { for (int64 v = interval.start; v <= interval.end; ++v) { const int64 value = context->MinOf(ct->element().vars(v)); const int index_lit = @@ -2184,7 +2186,7 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { v = inverse_value; } tuple[j] = v; - if (!context->DomainOf(r.representative).Contains(v)) { + if (!context->DomainContains(r.representative, v)) { delete_row = true; break; } diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index 75d7a5d09a..4ef09f4b8d 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -637,26 +637,22 @@ void LinearProgrammingConstraint::UpdateSimplexIterationLimit( if (sat_parameters_.linearization_level() < 2) return; const int64 num_degenerate_columns = CalculateDegeneracy(); const int64 num_cols = simplex_.GetProblemNumCols().value(); - const bool high_degeneracy = num_degenerate_columns >= 0.5 * num_cols; - const bool medium_degeneracy = num_degenerate_columns >= 0.3 * num_cols; + const bool is_degenerate = num_degenerate_columns >= 0.3 * num_cols; + const int64 decrease_factor = 10 * num_degenerate_columns / num_cols; if (simplex_.GetProblemStatus() == glop::ProblemStatus::DUAL_FEASIBLE) { // We reached here probably because we predicted wrong. We use this as a // signal to increase the iterations or punish less for degeneracy compare // to the other part. // TODO(user): Derive a formula to update the limit using degeneracy to // simplify the code. - if (high_degeneracy) { - next_simplex_iter_ /= 5; - } else if (medium_degeneracy) { - next_simplex_iter_ /= 2; + if (is_degenerate) { + next_simplex_iter_ /= decrease_factor; } else { next_simplex_iter_ *= 2; } } else if (simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL) { - if (high_degeneracy) { - next_simplex_iter_ /= 10; - } else if (medium_degeneracy) { - next_simplex_iter_ /= 5; + if (is_degenerate) { + next_simplex_iter_ /= 2 * decrease_factor; } else { // This is the most common case. We use the size of the problem to // determine the limit and ignore the previous limit. diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index 86810dac40..902cacc0c1 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -14,6 +14,7 @@ #include "ortools/sat/util.h" #include +#include namespace operations_research { namespace sat { @@ -74,5 +75,37 @@ void ExponentialMovingAverage::AddData(double new_record) { : (new_record + decaying_factor_ * (average_ - new_record)); } +void Percentile::AddRecord(double record) { + records_.push_front(record); + if (records_.size() > record_limit_) { + records_.pop_back(); + } +} + +double Percentile::GetPercentile(double percent) { + CHECK_GT(records_.size(), 0); + CHECK_LE(percent, 100.0); + CHECK_GE(percent, 0.0); + std::vector sorted_records(records_.begin(), records_.end()); + std::sort(sorted_records.begin(), sorted_records.end()); + const int num_records = sorted_records.size(); + + const double percentile_rank = + static_cast(num_records) * percent / 100.0 - 0.5; + if (percentile_rank <= 0) { + return sorted_records.front(); + } else if (percentile_rank >= num_records - 1) { + return sorted_records.back(); + } + // Interpolate. + DCHECK_GE(num_records, 2); + DCHECK_LT(percentile_rank, num_records - 1); + const int lower_rank = static_cast(std::floor(percentile_rank)); + DCHECK_LT(lower_rank, num_records - 1); + return sorted_records[lower_rank] + + (percentile_rank - lower_rank) * + (sorted_records[lower_rank + 1] - sorted_records[lower_rank]); +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/util.h b/ortools/sat/util.h index a1b429d588..e0bcd5ea1c 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -14,6 +14,8 @@ #ifndef OR_TOOLS_SAT_UTIL_H_ #define OR_TOOLS_SAT_UTIL_H_ +#include + #include "ortools/base/random.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -144,6 +146,31 @@ class ExponentialMovingAverage { const double decaying_factor_; }; +// Utility to calculate percentile (First variant) for limited number of +// records. Reference: https://en.wikipedia.org/wiki/Percentile +// +// After the vector is sorted, we assume that the element with index i +// correspond to the percentile 100*(i+0.5)/size. For percentiles before the +// first element (resp. after the last one) we return the first element (resp. +// the last). And otherwise we do a linear interpolation between the two element +// around the asked percentile. +class Percentile { + public: + explicit Percentile(int record_limit) : record_limit_(record_limit) {} + + void AddRecord(double record); + + // Returns number of stored records. + int64 NumRecords() const { return records_.size(); } + + // Note that this is not fast and runs in O(n log n) for n records. + double GetPercentile(double percent); + + private: + std::deque records_; + const int record_limit_; +}; + } // namespace sat } // namespace operations_research diff --git a/ortools/util/sorted_interval_list.cc b/ortools/util/sorted_interval_list.cc index 2531dabdd7..8974cc7c53 100644 --- a/ortools/util/sorted_interval_list.cc +++ b/ortools/util/sorted_interval_list.cc @@ -185,12 +185,15 @@ int64 Domain::Max() const { return intervals_.back().end; } -// TODO(user): binary search if size is large? bool Domain::Contains(int64 value) const { - for (const ClosedInterval& interval : intervals_) { - if (interval.start <= value && interval.end >= value) return true; - } - return false; + // Because we only compare by start and there is no duplicate starts, this + // should be the next interval after the one that has a chance to contains + // value. + auto it = std::upper_bound(intervals_.begin(), intervals_.end(), + ClosedInterval(value, value)); + if (it == intervals_.begin()) return false; + --it; + return value <= it->end; } bool Domain::IsIncludedIn(const Domain& domain) const { diff --git a/ortools/util/sorted_interval_list.h b/ortools/util/sorted_interval_list.h index 3ca07a39e4..6a3052f41d 100644 --- a/ortools/util/sorted_interval_list.h +++ b/ortools/util/sorted_interval_list.h @@ -27,8 +27,8 @@ namespace operations_research { // Represents a closed interval [start, end]. We must have start <= end. struct ClosedInterval { - int64 start; // Inclusive. - int64 end; // Inclusive. + ClosedInterval() {} + ClosedInterval(int64 s, int64 e) : start(s), end(e) {} std::string DebugString() const; bool operator==(const ClosedInterval& other) const { @@ -41,6 +41,9 @@ struct ClosedInterval { bool operator<(const ClosedInterval& other) const { return start < other.start; } + + int64 start = 0; // Inclusive. + int64 end = 0; // Inclusive. }; std::ostream& operator<<(std::ostream& out, const ClosedInterval& interval);