polish inverse expansion code
This commit is contained in:
@@ -371,60 +371,66 @@ void ExpandInverse(ConstraintProto* ct, PresolveContext* context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce the domains of each variable by checking that the inverse value
|
||||
// exists.
|
||||
std::vector<int64> possible_values;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
possible_values.clear();
|
||||
const Domain domain = context->DomainOf(ct->inverse().f_direct(i));
|
||||
bool removed_value = false;
|
||||
for (const ClosedInterval& interval : domain) {
|
||||
for (int64 j = interval.start; j <= interval.end; ++j) {
|
||||
if (context->DomainOf(ct->inverse().f_inverse(j)).Contains(i)) {
|
||||
possible_values.push_back(j);
|
||||
} else {
|
||||
removed_value = true;
|
||||
// Propagate from one vector to its counterpart.
|
||||
// Note this reaches the fixpoint as there is a one to one mapping between
|
||||
// (variable-value) pairs in each vector.
|
||||
const auto filter_inverse_domain = [context, size, &possible_values](
|
||||
const auto& direct,
|
||||
const auto& inverse) {
|
||||
// Propagate for the inverse vector to the direct vector.
|
||||
for (int i = 0; i < size; ++i) {
|
||||
possible_values.clear();
|
||||
const Domain domain = context->DomainOf(direct[i]);
|
||||
bool removed_value = false;
|
||||
for (const ClosedInterval& interval : domain) {
|
||||
for (int64 j = interval.start; j <= interval.end; ++j) {
|
||||
if (context->DomainOf(inverse[j]).Contains(i)) {
|
||||
possible_values.push_back(j);
|
||||
} else {
|
||||
removed_value = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (removed_value) {
|
||||
if (!context->IntersectDomainWith(
|
||||
direct[i], Domain::FromValues(possible_values))) {
|
||||
VLOG(1) << "Empty domain for a variable in ExpandInverse()";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (removed_value) {
|
||||
if (!context->IntersectDomainWith(ct->inverse().f_direct(i),
|
||||
Domain::FromValues(possible_values))) {
|
||||
VLOG(1) << "Empty domain for a variable in ExpandInverse()";
|
||||
return;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
if (!filter_inverse_domain(ct->inverse().f_direct(),
|
||||
ct->inverse().f_inverse())) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (int j = 0; j < size; ++j) {
|
||||
possible_values.clear();
|
||||
const Domain domain = context->DomainOf(ct->inverse().f_inverse(j));
|
||||
bool removed_value = false;
|
||||
for (const ClosedInterval& interval : domain) {
|
||||
for (int64 i = interval.start; i <= interval.end; ++i) {
|
||||
if (context->DomainOf(ct->inverse().f_direct(i)).Contains(j)) {
|
||||
possible_values.push_back(i);
|
||||
} else {
|
||||
removed_value = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (removed_value) {
|
||||
if (!context->IntersectDomainWith(ct->inverse().f_inverse(j),
|
||||
Domain::FromValues(possible_values))) {
|
||||
VLOG(1) << "Empty domain for a variable in ExpandInverse()";
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!filter_inverse_domain(ct->inverse().f_inverse(),
|
||||
ct->inverse().f_direct())) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Implement the inverse part.
|
||||
// Expand the inverse constraint by associating literal to var == value
|
||||
// and sharing them between the direct and inverse variables.
|
||||
for (int i = 0; i < size; ++i) {
|
||||
const int f_i = ct->inverse().f_direct(i);
|
||||
const Domain domain = context->DomainOf(f_i);
|
||||
for (const ClosedInterval& interval : domain) {
|
||||
for (int64 j = interval.start; j <= interval.end; ++j) {
|
||||
// We have f[i] == j <=> r[j] == i;
|
||||
const int r_j = ct->inverse().f_inverse(j);
|
||||
const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j);
|
||||
context->InsertVarValueEncoding(f_i_j, r_j, i);
|
||||
int r_j_i;
|
||||
if (context->HasVarValueEncoding(r_j, i, &r_j_i)) {
|
||||
context->InsertVarValueEncoding(r_j_i, f_i, j);
|
||||
} else {
|
||||
const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j);
|
||||
context->InsertVarValueEncoding(f_i_j, r_j, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -479,6 +479,21 @@ bool PresolveContext::StoreLiteralImpliesVarNEqValue(int literal, int var,
|
||||
return InsertHalfVarValueEncoding(literal, var, value, /*imply_eq=*/false);
|
||||
}
|
||||
|
||||
bool PresolveContext::HasVarValueEncoding(int ref, int64 value, int* literal) {
|
||||
const int var = PositiveRef(ref);
|
||||
const int64 var_value = RefIsPositive(ref) ? value : -value;
|
||||
const std::pair<int, int64> key{var, var_value};
|
||||
const auto& it = encoding.find(key);
|
||||
if (it != encoding.end()) {
|
||||
if (literal != nullptr) {
|
||||
*literal = it->second;
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) {
|
||||
// TODO(user,user): use affine relation here.
|
||||
const int var = PositiveRef(ref);
|
||||
|
||||
@@ -147,6 +147,10 @@ struct PresolveContext {
|
||||
// create it, add the corresponding constraints and returns it.
|
||||
int GetOrCreateVarValueEncoding(int ref, int64 value);
|
||||
|
||||
// Returns true if a literal attached to ref == var exists.
|
||||
// It assigns the corresponding to `literal` if non null.
|
||||
bool HasVarValueEncoding(int ref, int64 value, int* literal = nullptr);
|
||||
|
||||
// Stores the fact that literal implies var == value.
|
||||
// It returns true if that information is new.
|
||||
bool StoreLiteralImpliesVarEqValue(int literal, int var, int64 value);
|
||||
|
||||
Reference in New Issue
Block a user