optimize speed and memory
This commit is contained in:
@@ -131,12 +131,19 @@ bool PresolveContext::VariableIsUniqueAndRemovable(int ref) const {
|
||||
}
|
||||
|
||||
Domain PresolveContext::DomainOf(int ref) const {
|
||||
if (RefIsPositive(ref)) return domains[ref];
|
||||
return domains[PositiveRef(ref)].Negation();
|
||||
Domain result;
|
||||
if (RefIsPositive(ref)) {
|
||||
result = domains[ref];
|
||||
} else {
|
||||
result = domains[PositiveRef(ref)].Negation();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool PresolveContext::DomainContains(int ref, int64 value) const {
|
||||
if (!RefIsPositive(ref)) return DomainContains(NegatedRef(ref), -value);
|
||||
if (!RefIsPositive(ref)) {
|
||||
return domains[PositiveRef(ref)].Contains(-value);
|
||||
}
|
||||
return domains[ref].Contains(value);
|
||||
}
|
||||
|
||||
@@ -797,9 +804,9 @@ bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) {
|
||||
// infered_domain ∩ [kint64min, target_ub] ⊂ target_domain
|
||||
// then the constraint is really max(...) <= target_ub and we can simplify it.
|
||||
if (context->VariableIsUniqueAndRemovable(target_ref)) {
|
||||
const Domain target_domain = context->DomainOf(target_ref);
|
||||
const Domain& target_domain = context->DomainOf(target_ref);
|
||||
if (infered_domain.IntersectionWith(Domain(kint64min, target_domain.Max()))
|
||||
.IsIncludedIn(context->DomainOf(target_ref))) {
|
||||
.IsIncludedIn(target_domain)) {
|
||||
if (infered_domain.Max() <= target_domain.Max()) {
|
||||
// The constraint is always satisfiable.
|
||||
context->UpdateRuleStats("int_max: always true");
|
||||
@@ -1783,21 +1790,19 @@ bool PresolveInterval(int c, ConstraintProto* ct, PresolveContext* context) {
|
||||
|
||||
if (!ct->enforcement_literal().empty()) return false;
|
||||
bool changed = false;
|
||||
const Domain start_domain = context->DomainOf(start);
|
||||
const Domain end_domain = context->DomainOf(end);
|
||||
const Domain size_domain = context->DomainOf(size);
|
||||
if (!context->IntersectDomainWith(end, start_domain.AdditionWith(size_domain),
|
||||
&changed)) {
|
||||
return false;
|
||||
}
|
||||
if (!context->IntersectDomainWith(
|
||||
end, context->DomainOf(start).AdditionWith(context->DomainOf(size)),
|
||||
&changed)) {
|
||||
start, end_domain.AdditionWith(size_domain.Negation()), &changed)) {
|
||||
return false;
|
||||
}
|
||||
if (!context->IntersectDomainWith(start,
|
||||
context->DomainOf(end).AdditionWith(
|
||||
context->DomainOf(size).Negation()),
|
||||
&changed)) {
|
||||
return false;
|
||||
}
|
||||
if (!context->IntersectDomainWith(size,
|
||||
context->DomainOf(end).AdditionWith(
|
||||
context->DomainOf(start).Negation()),
|
||||
&changed)) {
|
||||
if (!context->IntersectDomainWith(
|
||||
size, end_domain.AdditionWith(start_domain.Negation()), &changed)) {
|
||||
return false;
|
||||
}
|
||||
if (changed) {
|
||||
@@ -1844,14 +1849,14 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
// Filter possible index values. Accumulate variable domains to build
|
||||
// a possible target domain.
|
||||
Domain infered_domain;
|
||||
const Domain initial_index_domain = context->DomainOf(index_ref);
|
||||
Domain target_domain = context->DomainOf(target_ref);
|
||||
const Domain& initial_index_domain = context->DomainOf(index_ref);
|
||||
const Domain& target_domain = context->DomainOf(target_ref);
|
||||
for (const ClosedInterval interval : initial_index_domain) {
|
||||
for (int value = interval.start; value <= interval.end; ++value) {
|
||||
CHECK_GE(value, 0);
|
||||
CHECK_LT(value, ct->element().vars_size());
|
||||
const int ref = ct->element().vars(value);
|
||||
const Domain domain = context->DomainOf(ref);
|
||||
const Domain& domain = context->DomainOf(ref);
|
||||
if (domain.IntersectionWith(target_domain).IsEmpty()) {
|
||||
bool domain_modified = false;
|
||||
if (!context->IntersectDomainWith(
|
||||
@@ -2000,7 +2005,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
// Eventually, domain sizes will be synchronized.
|
||||
bool changed_values = false;
|
||||
std::vector<int64> valid_index_values;
|
||||
const Domain index_domain = context->DomainOf(index_ref);
|
||||
const Domain& index_domain = context->DomainOf(index_ref);
|
||||
for (const ClosedInterval interval : index_domain) {
|
||||
for (int i = interval.start; i <= interval.end; ++i) {
|
||||
const int64 value = context->MinOf(ct->element().vars(i));
|
||||
@@ -2035,10 +2040,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
}
|
||||
|
||||
if (context->IsFixed(target_ref)) {
|
||||
const Domain index_domain = context->DomainOf(index_ref);
|
||||
const int64 target_value = context->MinOf(target_ref);
|
||||
|
||||
for (const ClosedInterval& interval : index_domain) {
|
||||
for (const ClosedInterval& interval : context->DomainOf(index_ref)) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
const int var = ct->element().vars(v);
|
||||
const int index_lit =
|
||||
@@ -2054,8 +2057,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
|
||||
if (target_ref == index_ref) {
|
||||
// Filter impossible index values.
|
||||
Domain index_domain = context->DomainOf(index_ref);
|
||||
std::vector<int64> possible_indices;
|
||||
const Domain& index_domain = context->DomainOf(index_ref);
|
||||
for (const ClosedInterval& interval : index_domain) {
|
||||
for (int64 value = interval.start; value <= interval.end; ++value) {
|
||||
const int ref = ct->element().vars(value);
|
||||
@@ -2077,7 +2080,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
}
|
||||
|
||||
if (all_constants) {
|
||||
const Domain index_domain = context->DomainOf(index_ref);
|
||||
absl::flat_hash_map<int, std::vector<int>> supports;
|
||||
|
||||
// Help linearization.
|
||||
@@ -2087,7 +2089,7 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
lin->add_coeffs(-1);
|
||||
int64 rhs = 0;
|
||||
|
||||
for (const ClosedInterval& interval : index_domain) {
|
||||
for (const ClosedInterval& interval : context->DomainOf(index_ref)) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
const int64 value = context->MinOf(ct->element().vars(v));
|
||||
const int index_lit =
|
||||
@@ -2184,7 +2186,7 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) {
|
||||
v = inverse_value;
|
||||
}
|
||||
tuple[j] = v;
|
||||
if (!context->DomainOf(r.representative).Contains(v)) {
|
||||
if (!context->DomainContains(r.representative, v)) {
|
||||
delete_row = true;
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -637,26 +637,22 @@ void LinearProgrammingConstraint::UpdateSimplexIterationLimit(
|
||||
if (sat_parameters_.linearization_level() < 2) return;
|
||||
const int64 num_degenerate_columns = CalculateDegeneracy();
|
||||
const int64 num_cols = simplex_.GetProblemNumCols().value();
|
||||
const bool high_degeneracy = num_degenerate_columns >= 0.5 * num_cols;
|
||||
const bool medium_degeneracy = num_degenerate_columns >= 0.3 * num_cols;
|
||||
const bool is_degenerate = num_degenerate_columns >= 0.3 * num_cols;
|
||||
const int64 decrease_factor = 10 * num_degenerate_columns / num_cols;
|
||||
if (simplex_.GetProblemStatus() == glop::ProblemStatus::DUAL_FEASIBLE) {
|
||||
// We reached here probably because we predicted wrong. We use this as a
|
||||
// signal to increase the iterations or punish less for degeneracy compare
|
||||
// to the other part.
|
||||
// TODO(user): Derive a formula to update the limit using degeneracy to
|
||||
// simplify the code.
|
||||
if (high_degeneracy) {
|
||||
next_simplex_iter_ /= 5;
|
||||
} else if (medium_degeneracy) {
|
||||
next_simplex_iter_ /= 2;
|
||||
if (is_degenerate) {
|
||||
next_simplex_iter_ /= decrease_factor;
|
||||
} else {
|
||||
next_simplex_iter_ *= 2;
|
||||
}
|
||||
} else if (simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL) {
|
||||
if (high_degeneracy) {
|
||||
next_simplex_iter_ /= 10;
|
||||
} else if (medium_degeneracy) {
|
||||
next_simplex_iter_ /= 5;
|
||||
if (is_degenerate) {
|
||||
next_simplex_iter_ /= 2 * decrease_factor;
|
||||
} else {
|
||||
// This is the most common case. We use the size of the problem to
|
||||
// determine the limit and ignore the previous limit.
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ortools/sat/util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
@@ -74,5 +75,37 @@ void ExponentialMovingAverage::AddData(double new_record) {
|
||||
: (new_record + decaying_factor_ * (average_ - new_record));
|
||||
}
|
||||
|
||||
void Percentile::AddRecord(double record) {
|
||||
records_.push_front(record);
|
||||
if (records_.size() > record_limit_) {
|
||||
records_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
double Percentile::GetPercentile(double percent) {
|
||||
CHECK_GT(records_.size(), 0);
|
||||
CHECK_LE(percent, 100.0);
|
||||
CHECK_GE(percent, 0.0);
|
||||
std::vector<double> sorted_records(records_.begin(), records_.end());
|
||||
std::sort(sorted_records.begin(), sorted_records.end());
|
||||
const int num_records = sorted_records.size();
|
||||
|
||||
const double percentile_rank =
|
||||
static_cast<double>(num_records) * percent / 100.0 - 0.5;
|
||||
if (percentile_rank <= 0) {
|
||||
return sorted_records.front();
|
||||
} else if (percentile_rank >= num_records - 1) {
|
||||
return sorted_records.back();
|
||||
}
|
||||
// Interpolate.
|
||||
DCHECK_GE(num_records, 2);
|
||||
DCHECK_LT(percentile_rank, num_records - 1);
|
||||
const int lower_rank = static_cast<int>(std::floor(percentile_rank));
|
||||
DCHECK_LT(lower_rank, num_records - 1);
|
||||
return sorted_records[lower_rank] +
|
||||
(percentile_rank - lower_rank) *
|
||||
(sorted_records[lower_rank + 1] - sorted_records[lower_rank]);
|
||||
}
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
#ifndef OR_TOOLS_SAT_UTIL_H_
|
||||
#define OR_TOOLS_SAT_UTIL_H_
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "ortools/base/random.h"
|
||||
#include "ortools/sat/model.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
@@ -144,6 +146,31 @@ class ExponentialMovingAverage {
|
||||
const double decaying_factor_;
|
||||
};
|
||||
|
||||
// Utility to calculate percentile (First variant) for limited number of
|
||||
// records. Reference: https://en.wikipedia.org/wiki/Percentile
|
||||
//
|
||||
// After the vector is sorted, we assume that the element with index i
|
||||
// correspond to the percentile 100*(i+0.5)/size. For percentiles before the
|
||||
// first element (resp. after the last one) we return the first element (resp.
|
||||
// the last). And otherwise we do a linear interpolation between the two element
|
||||
// around the asked percentile.
|
||||
class Percentile {
|
||||
public:
|
||||
explicit Percentile(int record_limit) : record_limit_(record_limit) {}
|
||||
|
||||
void AddRecord(double record);
|
||||
|
||||
// Returns number of stored records.
|
||||
int64 NumRecords() const { return records_.size(); }
|
||||
|
||||
// Note that this is not fast and runs in O(n log n) for n records.
|
||||
double GetPercentile(double percent);
|
||||
|
||||
private:
|
||||
std::deque<double> records_;
|
||||
const int record_limit_;
|
||||
};
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
|
||||
|
||||
@@ -185,12 +185,15 @@ int64 Domain::Max() const {
|
||||
return intervals_.back().end;
|
||||
}
|
||||
|
||||
// TODO(user): binary search if size is large?
|
||||
bool Domain::Contains(int64 value) const {
|
||||
for (const ClosedInterval& interval : intervals_) {
|
||||
if (interval.start <= value && interval.end >= value) return true;
|
||||
}
|
||||
return false;
|
||||
// Because we only compare by start and there is no duplicate starts, this
|
||||
// should be the next interval after the one that has a chance to contains
|
||||
// value.
|
||||
auto it = std::upper_bound(intervals_.begin(), intervals_.end(),
|
||||
ClosedInterval(value, value));
|
||||
if (it == intervals_.begin()) return false;
|
||||
--it;
|
||||
return value <= it->end;
|
||||
}
|
||||
|
||||
bool Domain::IsIncludedIn(const Domain& domain) const {
|
||||
|
||||
@@ -27,8 +27,8 @@ namespace operations_research {
|
||||
|
||||
// Represents a closed interval [start, end]. We must have start <= end.
|
||||
struct ClosedInterval {
|
||||
int64 start; // Inclusive.
|
||||
int64 end; // Inclusive.
|
||||
ClosedInterval() {}
|
||||
ClosedInterval(int64 s, int64 e) : start(s), end(e) {}
|
||||
|
||||
std::string DebugString() const;
|
||||
bool operator==(const ClosedInterval& other) const {
|
||||
@@ -41,6 +41,9 @@ struct ClosedInterval {
|
||||
bool operator<(const ClosedInterval& other) const {
|
||||
return start < other.start;
|
||||
}
|
||||
|
||||
int64 start = 0; // Inclusive.
|
||||
int64 end = 0; // Inclusive.
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ClosedInterval& interval);
|
||||
|
||||
Reference in New Issue
Block a user