optimize speed and memory

This commit is contained in:
Laurent Perron
2019-05-23 18:55:20 +02:00
parent 71405ff5b1
commit 198d6295e4
6 changed files with 109 additions and 45 deletions

View File

@@ -131,12 +131,19 @@ bool PresolveContext::VariableIsUniqueAndRemovable(int ref) const {
}
Domain PresolveContext::DomainOf(int ref) const {
if (RefIsPositive(ref)) return domains[ref];
return domains[PositiveRef(ref)].Negation();
Domain result;
if (RefIsPositive(ref)) {
result = domains[ref];
} else {
result = domains[PositiveRef(ref)].Negation();
}
return result;
}
bool PresolveContext::DomainContains(int ref, int64 value) const {
if (!RefIsPositive(ref)) return DomainContains(NegatedRef(ref), -value);
if (!RefIsPositive(ref)) {
return domains[PositiveRef(ref)].Contains(-value);
}
return domains[ref].Contains(value);
}
@@ -797,9 +804,9 @@ bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) {
// infered_domain ∩ [kint64min, target_ub] ⊂ target_domain
// then the constraint is really max(...) <= target_ub and we can simplify it.
if (context->VariableIsUniqueAndRemovable(target_ref)) {
const Domain target_domain = context->DomainOf(target_ref);
const Domain& target_domain = context->DomainOf(target_ref);
if (infered_domain.IntersectionWith(Domain(kint64min, target_domain.Max()))
.IsIncludedIn(context->DomainOf(target_ref))) {
.IsIncludedIn(target_domain)) {
if (infered_domain.Max() <= target_domain.Max()) {
// The constraint is always satisfiable.
context->UpdateRuleStats("int_max: always true");
@@ -1783,21 +1790,19 @@ bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) {
if (!ct->enforcement_literal().empty()) return false;
bool changed = false;
const Domain start_domain = context->DomainOf(start);
const Domain end_domain = context->DomainOf(end);
const Domain size_domain = context->DomainOf(size);
if (!context->IntersectDomainWith(end, start_domain.AdditionWith(size_domain),
&changed)) {
return false;
}
if (!context->IntersectDomainWith(
end, context->DomainOf(start).AdditionWith(context->DomainOf(size)),
&changed)) {
start, end_domain.AdditionWith(size_domain.Negation()), &changed)) {
return false;
}
if (!context->IntersectDomainWith(start,
context->DomainOf(end).AdditionWith(
context->DomainOf(size).Negation()),
&changed)) {
return false;
}
if (!context->IntersectDomainWith(size,
context->DomainOf(end).AdditionWith(
context->DomainOf(start).Negation()),
&changed)) {
if (!context->IntersectDomainWith(
size, end_domain.AdditionWith(start_domain.Negation()), &changed)) {
return false;
}
if (changed) {
@@ -1844,14 +1849,14 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
// Filter possible index values. Accumulate variable domains to build
// a possible target domain.
Domain infered_domain;
const Domain initial_index_domain = context->DomainOf(index_ref);
Domain target_domain = context->DomainOf(target_ref);
const Domain& initial_index_domain = context->DomainOf(index_ref);
const Domain& target_domain = context->DomainOf(target_ref);
for (const ClosedInterval interval : initial_index_domain) {
for (int value = interval.start; value <= interval.end; ++value) {
CHECK_GE(value, 0);
CHECK_LT(value, ct->element().vars_size());
const int ref = ct->element().vars(value);
const Domain domain = context->DomainOf(ref);
const Domain& domain = context->DomainOf(ref);
if (domain.IntersectionWith(target_domain).IsEmpty()) {
bool domain_modified = false;
if (!context->IntersectDomainWith(
@@ -2000,7 +2005,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
// Eventually, domain sizes will be synchronized.
bool changed_values = false;
std::vector<int64> valid_index_values;
const Domain index_domain = context->DomainOf(index_ref);
const Domain& index_domain = context->DomainOf(index_ref);
for (const ClosedInterval interval : index_domain) {
for (int i = interval.start; i <= interval.end; ++i) {
const int64 value = context->MinOf(ct->element().vars(i));
@@ -2035,10 +2040,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
}
if (context->IsFixed(target_ref)) {
const Domain index_domain = context->DomainOf(index_ref);
const int64 target_value = context->MinOf(target_ref);
for (const ClosedInterval& interval : index_domain) {
for (const ClosedInterval& interval : context->DomainOf(index_ref)) {
for (int64 v = interval.start; v <= interval.end; ++v) {
const int var = ct->element().vars(v);
const int index_lit =
@@ -2054,8 +2057,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
if (target_ref == index_ref) {
// Filter impossible index values.
Domain index_domain = context->DomainOf(index_ref);
std::vector<int64> possible_indices;
const Domain& index_domain = context->DomainOf(index_ref);
for (const ClosedInterval& interval : index_domain) {
for (int64 value = interval.start; value <= interval.end; ++value) {
const int ref = ct->element().vars(value);
@@ -2077,7 +2080,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
}
if (all_constants) {
const Domain index_domain = context->DomainOf(index_ref);
absl::flat_hash_map<int, std::vector<int>> supports;
// Help linearization.
@@ -2087,7 +2089,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
lin->add_coeffs(-1);
int64 rhs = 0;
for (const ClosedInterval& interval : index_domain) {
for (const ClosedInterval& interval : context->DomainOf(index_ref)) {
for (int64 v = interval.start; v <= interval.end; ++v) {
const int64 value = context->MinOf(ct->element().vars(v));
const int index_lit =
@@ -2184,7 +2186,7 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) {
v = inverse_value;
}
tuple[j] = v;
if (!context->DomainOf(r.representative).Contains(v)) {
if (!context->DomainContains(r.representative, v)) {
delete_row = true;
break;
}

View File

@@ -637,26 +637,22 @@ void LinearProgrammingConstraint::UpdateSimplexIterationLimit(
if (sat_parameters_.linearization_level() < 2) return;
const int64 num_degenerate_columns = CalculateDegeneracy();
const int64 num_cols = simplex_.GetProblemNumCols().value();
const bool high_degeneracy = num_degenerate_columns >= 0.5 * num_cols;
const bool medium_degeneracy = num_degenerate_columns >= 0.3 * num_cols;
const bool is_degenerate = num_degenerate_columns >= 0.3 * num_cols;
const int64 decrease_factor = 10 * num_degenerate_columns / num_cols;
if (simplex_.GetProblemStatus() == glop::ProblemStatus::DUAL_FEASIBLE) {
// We reached here probably because we predicted wrong. We use this as a
// signal to increase the iterations or punish less for degeneracy compare
// to the other part.
// TODO(user): Derive a formula to update the limit using degeneracy to
// simplify the code.
if (high_degeneracy) {
next_simplex_iter_ /= 5;
} else if (medium_degeneracy) {
next_simplex_iter_ /= 2;
if (is_degenerate) {
next_simplex_iter_ /= decrease_factor;
} else {
next_simplex_iter_ *= 2;
}
} else if (simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL) {
if (high_degeneracy) {
next_simplex_iter_ /= 10;
} else if (medium_degeneracy) {
next_simplex_iter_ /= 5;
if (is_degenerate) {
next_simplex_iter_ /= 2 * decrease_factor;
} else {
// This is the most common case. We use the size of the problem to
// determine the limit and ignore the previous limit.

View File

@@ -14,6 +14,7 @@
#include "ortools/sat/util.h"
#include <algorithm>
#include <cmath>
namespace operations_research {
namespace sat {
@@ -74,5 +75,37 @@ void ExponentialMovingAverage::AddData(double new_record) {
: (new_record + decaying_factor_ * (average_ - new_record));
}
void Percentile::AddRecord(double record) {
records_.push_front(record);
if (records_.size() > record_limit_) {
records_.pop_back();
}
}
double Percentile::GetPercentile(double percent) {
CHECK_GT(records_.size(), 0);
CHECK_LE(percent, 100.0);
CHECK_GE(percent, 0.0);
std::vector<double> sorted_records(records_.begin(), records_.end());
std::sort(sorted_records.begin(), sorted_records.end());
const int num_records = sorted_records.size();
const double percentile_rank =
static_cast<double>(num_records) * percent / 100.0 - 0.5;
if (percentile_rank <= 0) {
return sorted_records.front();
} else if (percentile_rank >= num_records - 1) {
return sorted_records.back();
}
// Interpolate.
DCHECK_GE(num_records, 2);
DCHECK_LT(percentile_rank, num_records - 1);
const int lower_rank = static_cast<int>(std::floor(percentile_rank));
DCHECK_LT(lower_rank, num_records - 1);
return sorted_records[lower_rank] +
(percentile_rank - lower_rank) *
(sorted_records[lower_rank + 1] - sorted_records[lower_rank]);
}
} // namespace sat
} // namespace operations_research

View File

@@ -14,6 +14,8 @@
#ifndef OR_TOOLS_SAT_UTIL_H_
#define OR_TOOLS_SAT_UTIL_H_
#include <deque>
#include "ortools/base/random.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
@@ -144,6 +146,31 @@ class ExponentialMovingAverage {
const double decaying_factor_;
};
// Utility to calculate percentile (First variant) for limited number of
// records. Reference: https://en.wikipedia.org/wiki/Percentile
//
// After the vector is sorted, we assume that the element with index i
// correspond to the percentile 100*(i+0.5)/size. For percentiles before the
// first element (resp. after the last one) we return the first element (resp.
// the last). And otherwise we do a linear interpolation between the two element
// around the asked percentile.
class Percentile {
public:
explicit Percentile(int record_limit) : record_limit_(record_limit) {}
void AddRecord(double record);
// Returns number of stored records.
int64 NumRecords() const { return records_.size(); }
// Note that this is not fast and runs in O(n log n) for n records.
double GetPercentile(double percent);
private:
std::deque<double> records_;
const int record_limit_;
};
} // namespace sat
} // namespace operations_research

View File

@@ -185,12 +185,15 @@ int64 Domain::Max() const {
return intervals_.back().end;
}
// TODO(user): binary search if size is large?
bool Domain::Contains(int64 value) const {
for (const ClosedInterval& interval : intervals_) {
if (interval.start <= value && interval.end >= value) return true;
}
return false;
// Because we only compare by start and there is no duplicate starts, this
// should be the next interval after the one that has a chance to contains
// value.
auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
ClosedInterval(value, value));
if (it == intervals_.begin()) return false;
--it;
return value <= it->end;
}
bool Domain::IsIncludedIn(const Domain& domain) const {

View File

@@ -27,8 +27,8 @@ namespace operations_research {
// Represents a closed interval [start, end]. We must have start <= end.
struct ClosedInterval {
int64 start; // Inclusive.
int64 end; // Inclusive.
ClosedInterval() {}
ClosedInterval(int64 s, int64 e) : start(s), end(e) {}
std::string DebugString() const;
bool operator==(const ClosedInterval& other) const {
@@ -41,6 +41,9 @@ struct ClosedInterval {
bool operator<(const ClosedInterval& other) const {
return start < other.start;
}
int64 start = 0; // Inclusive.
int64 end = 0; // Inclusive.
};
std::ostream& operator<<(std::ostream& out, const ClosedInterval& interval);