some memory optimization
This commit is contained in:
@@ -419,19 +419,20 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
const int size = element.vars_size();
|
||||
if (!context->IntersectDomainWith(index_ref, Domain(0, size - 1))) {
|
||||
VLOG(1) << "Empty domain for the index variable in ExpandElement()";
|
||||
CHECK(!context->NotifyThatModelIsUnsat());
|
||||
return;
|
||||
}
|
||||
|
||||
bool all_constants = true;
|
||||
std::set<int64> reached_values;
|
||||
absl::flat_hash_set<int64> constant_var_values;
|
||||
std::vector<int64> invalid_indices;
|
||||
const Domain initial_index_domain = context->DomainOf(index_ref);
|
||||
const Domain initial_target_domain = context->DomainOf(target_ref);
|
||||
for (const ClosedInterval& interval : initial_index_domain) {
|
||||
Domain index_domain = context->DomainOf(index_ref);
|
||||
Domain target_domain = context->DomainOf(target_ref);
|
||||
for (const ClosedInterval& interval : index_domain) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
const int var = element.vars(v);
|
||||
const Domain var_domain = context->DomainOf(var);
|
||||
if (var_domain.IntersectionWith(initial_target_domain).IsEmpty()) {
|
||||
if (var_domain.IntersectionWith(target_domain).IsEmpty()) {
|
||||
invalid_indices.push_back(v);
|
||||
continue;
|
||||
}
|
||||
@@ -439,7 +440,7 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
all_constants = false;
|
||||
break;
|
||||
}
|
||||
reached_values.insert(var_domain.Min());
|
||||
constant_var_values.insert(var_domain.Min());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -447,31 +448,40 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
if (!context->IntersectDomainWith(
|
||||
index_ref, Domain::FromValues(invalid_indices).Complement())) {
|
||||
VLOG(1) << "No compatible variable domains in ExpandElement()";
|
||||
CHECK(!context->NotifyThatModelIsUnsat());
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-read the domain.
|
||||
index_domain = context->DomainOf(index_ref);
|
||||
}
|
||||
|
||||
const Domain index_domain = context->DomainOf(index_ref);
|
||||
|
||||
std::map<int64, BoolArgumentProto*> supports;
|
||||
// This BoolOrs implements the deduction that if all index literals pointing
|
||||
// to the same values in the constant array are false, then this value is no
|
||||
// no longer valid for the target variable.
|
||||
// Order is not important.
|
||||
absl::flat_hash_map<int64, BoolArgumentProto*> supports;
|
||||
if (all_constants && target_ref != index_ref) {
|
||||
if (!context->IntersectDomainWith(
|
||||
target_ref, Domain::FromValues(
|
||||
{reached_values.begin(), reached_values.end()}))) {
|
||||
target_ref, Domain::FromValues({constant_var_values.begin(),
|
||||
constant_var_values.end()}))) {
|
||||
VLOG(1) << "Empty domain for the target variable in ExpandElement()";
|
||||
return;
|
||||
}
|
||||
|
||||
const Domain domain = context->DomainOf(target_ref);
|
||||
if (domain.Size() == 1) {
|
||||
context->UpdateRuleStats("element: array is constant");
|
||||
target_domain = context->DomainOf(target_ref);
|
||||
if (target_domain.Size() == 1) {
|
||||
context->UpdateRuleStats("element: one value array");
|
||||
ct->Clear();
|
||||
return;
|
||||
}
|
||||
|
||||
for (const ClosedInterval& interval : context->DomainOf(target_ref)) {
|
||||
// TODO(user): only create 1 literal if the value has only one support.
|
||||
|
||||
for (const ClosedInterval& interval : target_domain) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
const int lit = context->GetOrCreateVarValueEncoding(target_ref, v);
|
||||
CHECK(gtl::ContainsKey(reached_values, v));
|
||||
CHECK(constant_var_values.contains(v));
|
||||
supports[v] =
|
||||
context->working_model->add_constraints()->mutable_bool_or();
|
||||
supports[v]->add_literals(NegatedRef(lit));
|
||||
@@ -479,8 +489,6 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
}
|
||||
}
|
||||
|
||||
const Domain target_domain = context->DomainOf(target_ref);
|
||||
|
||||
// While this is not stricly needed since all value in the index will be
|
||||
// covered, it allows to easily detect this fact in the presolve.
|
||||
auto* bool_or = context->working_model->add_constraints()->mutable_bool_or();
|
||||
@@ -502,7 +510,9 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
} else if (var_domain.Size() == 1) {
|
||||
context->AddImplyInDomain(index_lit, target_ref, var_domain);
|
||||
if (all_constants) {
|
||||
supports[var_domain.Min()]->add_literals(index_lit);
|
||||
BoolArgumentProto* const support =
|
||||
gtl::FindOrDie(supports, var_domain.Min());
|
||||
support->add_literals(index_lit);
|
||||
}
|
||||
} else {
|
||||
ConstraintProto* const ct = context->working_model->add_constraints();
|
||||
@@ -556,6 +566,9 @@ void ExpandCpModel(CpModelProto* working_model, PresolveOptions options) {
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
// Early exit if the model is unsat.
|
||||
if (context.ModelIsUnsat()) return;
|
||||
}
|
||||
|
||||
// Update any changed domain from the context.
|
||||
|
||||
@@ -96,11 +96,7 @@ void FillDomainInProto(const Domain& domain, ProtoWithDomain* proto) {
|
||||
// Reads a Domain from the domain field of a proto.
|
||||
template <typename ProtoWithDomain>
|
||||
Domain ReadDomainFromProto(const ProtoWithDomain& proto) {
|
||||
std::vector<ClosedInterval> intervals;
|
||||
for (int i = 0; i < proto.domain_size(); i += 2) {
|
||||
intervals.push_back({proto.domain(i), proto.domain(i + 1)});
|
||||
}
|
||||
return Domain::FromIntervals(intervals);
|
||||
return Domain::FromFlatSpanOfIntervals(proto.domain());
|
||||
}
|
||||
|
||||
// Returns the list of values in a given domain.
|
||||
|
||||
@@ -50,47 +50,46 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) {
|
||||
}
|
||||
|
||||
// Mark var and Negation(var) as fully encoded.
|
||||
CHECK_LT(var.value(), is_fully_encoded_.size());
|
||||
CHECK_LT(NegationOf(var).value(), is_fully_encoded_.size());
|
||||
CHECK(!equality_by_var_[var].empty());
|
||||
CHECK(!equality_by_var_[NegationOf(var)].empty());
|
||||
is_fully_encoded_[var] = true;
|
||||
is_fully_encoded_[NegationOf(var)] = true;
|
||||
CHECK_LT(GetPositiveOnlyIndex(var), is_fully_encoded_.size());
|
||||
CHECK(!equality_by_var_[GetPositiveOnlyIndex(var)].empty());
|
||||
is_fully_encoded_[GetPositiveOnlyIndex(var)] = true;
|
||||
}
|
||||
|
||||
bool IntegerEncoder::VariableIsFullyEncoded(IntegerVariable var) const {
|
||||
if (var >= is_fully_encoded_.size()) return false;
|
||||
const PositiveOnlyIndex index = GetPositiveOnlyIndex(var);
|
||||
if (index >= is_fully_encoded_.size()) return false;
|
||||
|
||||
// Once fully encoded, the status never changes.
|
||||
if (is_fully_encoded_[var]) return true;
|
||||
if (is_fully_encoded_[index]) return true;
|
||||
if (!VariableIsPositive(var)) var = PositiveVariable(var);
|
||||
|
||||
// TODO(user): Cache result as long as equality_by_var_[var] is unchanged?
|
||||
// TODO(user): Cache result as long as equality_by_var_[index] is unchanged?
|
||||
// It might not be needed since if the variable is not fully encoded, then
|
||||
// PartialDomainEncoding() will filter unreachable values, and so the size
|
||||
// check will be false until further value have been encoded.
|
||||
const int64 initial_domain_size = (*domains_)[var].Size();
|
||||
if (equality_by_var_[var].size() < initial_domain_size) return false;
|
||||
if (equality_by_var_[index].size() < initial_domain_size) return false;
|
||||
|
||||
// This cleans equality_by_var_[var] as a side effect and in particular, sorts
|
||||
// it by values.
|
||||
// This cleans equality_by_var_[index] as a side effect and in particular,
|
||||
// sorts it by values.
|
||||
PartialDomainEncoding(var);
|
||||
|
||||
// TODO(user): Comparing the size might be enough, but we want to be always
|
||||
// valid even if either (*domains_[var]) or PartialDomainEncoding(var) are
|
||||
// not properly synced because the propagation is not finished.
|
||||
const auto& ref = equality_by_var_[var];
|
||||
int index = 0;
|
||||
const auto& ref = equality_by_var_[index];
|
||||
int i = 0;
|
||||
for (const ClosedInterval interval : (*domains_)[var]) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
if (index < ref.size() && v == ref[index].value) {
|
||||
index++;
|
||||
if (i < ref.size() && v == ref[i].value) {
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (index == ref.size()) {
|
||||
is_fully_encoded_[var] = true;
|
||||
if (i == ref.size()) {
|
||||
is_fully_encoded_[index] = true;
|
||||
}
|
||||
return is_fully_encoded_[var];
|
||||
return is_fully_encoded_[index];
|
||||
}
|
||||
|
||||
std::vector<IntegerEncoder::ValueLiteralPair>
|
||||
@@ -102,23 +101,31 @@ IntegerEncoder::FullDomainEncoding(IntegerVariable var) const {
|
||||
std::vector<IntegerEncoder::ValueLiteralPair>
|
||||
IntegerEncoder::PartialDomainEncoding(IntegerVariable var) const {
|
||||
CHECK_EQ(sat_solver_->CurrentDecisionLevel(), 0);
|
||||
if (var >= equality_by_var_.size()) return {};
|
||||
const PositiveOnlyIndex index = GetPositiveOnlyIndex(var);
|
||||
if (index >= equality_by_var_.size()) return {};
|
||||
|
||||
int new_size = 0;
|
||||
std::vector<ValueLiteralPair>& ref = equality_by_var_[var];
|
||||
std::vector<ValueLiteralPair>& ref = equality_by_var_[index];
|
||||
for (int i = 0; i < ref.size(); ++i) {
|
||||
const ValueLiteralPair pair = ref[i];
|
||||
if (sat_solver_->Assignment().LiteralIsFalse(pair.literal)) continue;
|
||||
if (sat_solver_->Assignment().LiteralIsTrue(pair.literal)) {
|
||||
ref.clear();
|
||||
ref.push_back(pair);
|
||||
return ref;
|
||||
new_size = 1;
|
||||
break;
|
||||
}
|
||||
ref[new_size++] = pair;
|
||||
}
|
||||
ref.resize(new_size);
|
||||
std::sort(ref.begin(), ref.end());
|
||||
return ref;
|
||||
|
||||
std::vector<IntegerEncoder::ValueLiteralPair> result = ref;
|
||||
if (!VariableIsPositive(var)) {
|
||||
std::reverse(result.begin(), result.end());
|
||||
for (ValueLiteralPair& ref : result) ref.value = -ref.value;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Note that by not inserting the literal in "order" we can in the worst case
|
||||
@@ -219,10 +226,18 @@ Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i_lit) {
|
||||
return literal;
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::pair<PositiveOnlyIndex, IntegerValue> PositiveVarKey(IntegerVariable var,
|
||||
IntegerValue value) {
|
||||
return std::make_pair(GetPositiveOnlyIndex(var),
|
||||
VariableIsPositive(var) ? value : -value);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
LiteralIndex IntegerEncoder::GetAssociatedEqualityLiteral(
|
||||
IntegerVariable var, IntegerValue value) const {
|
||||
const std::pair<IntegerVariable, IntegerValue> key{var, value};
|
||||
const auto it = equality_to_associated_literal_.find(key);
|
||||
const auto it =
|
||||
equality_to_associated_literal_.find(PositiveVarKey(var, value));
|
||||
if (it != equality_to_associated_literal_.end()) {
|
||||
return it->second.Index();
|
||||
}
|
||||
@@ -232,8 +247,8 @@ LiteralIndex IntegerEncoder::GetAssociatedEqualityLiteral(
|
||||
Literal IntegerEncoder::GetOrCreateLiteralAssociatedToEquality(
|
||||
IntegerVariable var, IntegerValue value) {
|
||||
{
|
||||
const std::pair<IntegerVariable, IntegerValue> key{var, value};
|
||||
const auto it = equality_to_associated_literal_.find(key);
|
||||
const auto it =
|
||||
equality_to_associated_literal_.find(PositiveVarKey(var, value));
|
||||
if (it != equality_to_associated_literal_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
@@ -315,8 +330,8 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
|
||||
|
||||
// We use the "do not insert if present" behavior of .insert() to do just one
|
||||
// lookup.
|
||||
const auto insert_result =
|
||||
equality_to_associated_literal_.insert({{var, value}, literal});
|
||||
const auto insert_result = equality_to_associated_literal_.insert(
|
||||
{PositiveVarKey(var, value), literal});
|
||||
if (!insert_result.second) {
|
||||
// If this key is already associated, make the two literals equal.
|
||||
const Literal representative = insert_result.first->second;
|
||||
@@ -327,8 +342,6 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
|
||||
}
|
||||
return;
|
||||
}
|
||||
gtl::InsertOrDieNoPrint(&equality_to_associated_literal_,
|
||||
{{NegationOf(var), -value}, literal});
|
||||
|
||||
// Fix literal for value outside the domain.
|
||||
if (!domain.Contains(value.value())) {
|
||||
@@ -339,14 +352,13 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
|
||||
// Update equality_by_var. Note that due to the
|
||||
// equality_to_associated_literal_ hash table, there should never be any
|
||||
// duplicate values for a given variable.
|
||||
const int needed_size = std::max(var.value(), NegationOf(var).value()) + 1;
|
||||
if (needed_size > equality_by_var_.size()) {
|
||||
equality_by_var_.resize(needed_size);
|
||||
is_fully_encoded_.resize(needed_size);
|
||||
const PositiveOnlyIndex index = GetPositiveOnlyIndex(var);
|
||||
if (index >= equality_by_var_.size()) {
|
||||
equality_by_var_.resize(index.value() + 1);
|
||||
is_fully_encoded_.resize(index.value() + 1);
|
||||
}
|
||||
equality_by_var_[var].push_back(ValueLiteralPair(value, literal));
|
||||
equality_by_var_[NegationOf(var)].push_back(
|
||||
ValueLiteralPair(-value, literal));
|
||||
equality_by_var_[index].push_back(
|
||||
ValueLiteralPair(VariableIsPositive(var) ? value : -value, literal));
|
||||
|
||||
// Fix literal for constant domain.
|
||||
if (value == domain.Min() && value == domain.Max()) {
|
||||
|
||||
@@ -136,6 +136,12 @@ inline IntegerVariable PositiveVariable(IntegerVariable i) {
|
||||
return IntegerVariable(i.value() & (~1));
|
||||
}
|
||||
|
||||
// Special type for storing only one thing for var and NegationOf(var).
|
||||
DEFINE_INT_TYPE(PositiveOnlyIndex, int32);
|
||||
inline PositiveOnlyIndex GetPositiveOnlyIndex(IntegerVariable var) {
|
||||
return PositiveOnlyIndex(var.value() / 2);
|
||||
}
|
||||
|
||||
// Returns the vector of the negated variables.
|
||||
std::vector<IntegerVariable> NegationOf(
|
||||
const std::vector<IntegerVariable>& vars);
|
||||
@@ -450,16 +456,18 @@ class IntegerEncoder {
|
||||
// Mapping (variable == value) -> associated literal. Note that even if
|
||||
// there is more than one literal associated to the same fact, we just keep
|
||||
// the first one that was added.
|
||||
absl::flat_hash_map<std::pair<IntegerVariable, IntegerValue>, Literal>
|
||||
//
|
||||
// Note that we only keep positive IntegerVariable here to reduce memory
|
||||
// usage.
|
||||
absl::flat_hash_map<std::pair<PositiveOnlyIndex, IntegerValue>, Literal>
|
||||
equality_to_associated_literal_;
|
||||
|
||||
// Mutable because this is lazily cleaned-up by PartialDomainEncoding().
|
||||
const std::vector<ValueLiteralPair> empty_value_literal_vector_;
|
||||
mutable gtl::ITIVector<IntegerVariable, std::vector<ValueLiteralPair>>
|
||||
mutable gtl::ITIVector<PositiveOnlyIndex, std::vector<ValueLiteralPair>>
|
||||
equality_by_var_;
|
||||
|
||||
// Variables that are fully encoded.
|
||||
mutable gtl::ITIVector<IntegerVariable, bool> is_fully_encoded_;
|
||||
mutable gtl::ITIVector<PositiveOnlyIndex, bool> is_fully_encoded_;
|
||||
|
||||
// A literal that is always true, convenient to encode trivial domains.
|
||||
// This will be lazily created when needed.
|
||||
|
||||
Reference in New Issue
Block a user