diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index 14aa06f463..f29a5f82ca 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -371,60 +371,66 @@ void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { } } + // Reduce the domains of each variable by checking that the inverse value + // exists. std::vector possible_values; - for (int i = 0; i < size; ++i) { - possible_values.clear(); - const Domain domain = context->DomainOf(ct->inverse().f_direct(i)); - bool removed_value = false; - for (const ClosedInterval& interval : domain) { - for (int64 j = interval.start; j <= interval.end; ++j) { - if (context->DomainOf(ct->inverse().f_inverse(j)).Contains(i)) { - possible_values.push_back(j); - } else { - removed_value = true; + // Propagate from one vector to its counterpart. + // Note this reaches the fixpoint as there is a one to one mapping between + // (variable-value) pairs in each vector. + const auto filter_inverse_domain = [context, size, &possible_values]( + const auto& direct, + const auto& inverse) { + // Propagate for the inverse vector to the direct vector. + for (int i = 0; i < size; ++i) { + possible_values.clear(); + const Domain domain = context->DomainOf(direct[i]); + bool removed_value = false; + for (const ClosedInterval& interval : domain) { + for (int64 j = interval.start; j <= interval.end; ++j) { + if (context->DomainOf(inverse[j]).Contains(i)) { + possible_values.push_back(j); + } else { + removed_value = true; + } + } + } + if (removed_value) { + if (!context->IntersectDomainWith( + direct[i], Domain::FromValues(possible_values))) { + VLOG(1) << "Empty domain for a variable in ExpandInverse()"; + return false; } } } - if (removed_value) { - if (!context->IntersectDomainWith(ct->inverse().f_direct(i), - Domain::FromValues(possible_values))) { - VLOG(1) << "Empty domain for a variable in ExpandInverse()"; - return; - } - } + return true; + }; + + if (!filter_inverse_domain(ct->inverse().f_direct(), + ct->inverse().f_inverse())) { + return; } - for (int j = 0; j < size; ++j) { - possible_values.clear(); - const Domain domain = context->DomainOf(ct->inverse().f_inverse(j)); - bool removed_value = false; - for (const ClosedInterval& interval : domain) { - for (int64 i = interval.start; i <= interval.end; ++i) { - if (context->DomainOf(ct->inverse().f_direct(i)).Contains(j)) { - possible_values.push_back(i); - } else { - removed_value = true; - } - } - } - if (removed_value) { - if (!context->IntersectDomainWith(ct->inverse().f_inverse(j), - Domain::FromValues(possible_values))) { - VLOG(1) << "Empty domain for a variable in ExpandInverse()"; - return; - } - } + if (!filter_inverse_domain(ct->inverse().f_inverse(), + ct->inverse().f_direct())) { + return; } - // Implement the inverse part. + // Expand the inverse constraint by associating literal to var == value + // and sharing them between the direct and inverse variables. for (int i = 0; i < size; ++i) { const int f_i = ct->inverse().f_direct(i); const Domain domain = context->DomainOf(f_i); for (const ClosedInterval& interval : domain) { for (int64 j = interval.start; j <= interval.end; ++j) { + // We have f[i] == j <=> r[j] == i; const int r_j = ct->inverse().f_inverse(j); - const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j); - context->InsertVarValueEncoding(f_i_j, r_j, i); + int r_j_i; + if (context->HasVarValueEncoding(r_j, i, &r_j_i)) { + context->InsertVarValueEncoding(r_j_i, f_i, j); + } else { + const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j); + context->InsertVarValueEncoding(f_i_j, r_j, i); + } } } } diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index ccb2f2bc21..c6c7563a8e 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -479,6 +479,21 @@ bool PresolveContext::StoreLiteralImpliesVarNEqValue(int literal, int var, return InsertHalfVarValueEncoding(literal, var, value, /*imply_eq=*/false); } +bool PresolveContext::HasVarValueEncoding(int ref, int64 value, int* literal) { + const int var = PositiveRef(ref); + const int64 var_value = RefIsPositive(ref) ? value : -value; + const std::pair key{var, var_value}; + const auto& it = encoding.find(key); + if (it != encoding.end()) { + if (literal != nullptr) { + *literal = it->second; + } + return true; + } else { + return false; + } +} + int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) { // TODO(user,user): use affine relation here. const int var = PositiveRef(ref); diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index c5a39e514d..aa786b3621 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -147,6 +147,10 @@ struct PresolveContext { // create it, add the corresponding constraints and returns it. int GetOrCreateVarValueEncoding(int ref, int64 value); + // Returns true if a literal attached to ref == var exists. + // It assigns the corresponding to `literal` if non null. + bool HasVarValueEncoding(int ref, int64 value, int* literal = nullptr); + // Stores the fact that literal implies var == value. // It returns true if that information is new. bool StoreLiteralImpliesVarEqValue(int literal, int var, int64 value);