From 88054f79e687174ea67ebc95dca7c06b35659b06 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 2 Sep 2019 12:00:34 +0200 Subject: [PATCH] big bang on CP-SAT: remove element and inverse constraints, expand them --- ortools/sat/cp_model_expand.cc | 71 ++++++++++++++++++++++-- ortools/sat/presolve_util.cc | 99 ++++++++++++++++++++++++++++++++++ ortools/sat/presolve_util.h | 80 +++++++++++++++++++++++++++ 3 files changed, 247 insertions(+), 3 deletions(-) create mode 100644 ortools/sat/presolve_util.cc create mode 100644 ortools/sat/presolve_util.h diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index c854f2bec7..3dd1b8dca6 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -415,9 +415,66 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) { const int size = element.vars_size(); if (!context->IntersectDomainWith(index_ref, Domain(0, size - 1))) { VLOG(1) << "Empty domain for the index variable in ExpandElement()"; + return; + } + + bool all_constants = true; + std::set reached_values; + std::vector invalid_indices; + const Domain initial_index_domain = context->DomainOf(index_ref); + const Domain initial_target_domain = context->DomainOf(target_ref); + for (const ClosedInterval& interval : initial_index_domain) { + for (int64 v = interval.start; v <= interval.end; ++v) { + const int var = element.vars(v); + const Domain var_domain = context->DomainOf(var); + if (var_domain.IntersectionWith(initial_target_domain).IsEmpty()) { + invalid_indices.push_back(v); + continue; + } + if (var_domain.Min() != var_domain.Max()) { + all_constants = false; + break; + } + reached_values.insert(var_domain.Min()); + } + } + + if (!invalid_indices.empty()) { + if (!context->IntersectDomainWith( + index_ref, Domain::FromValues(invalid_indices).Complement())) { + VLOG(1) << "No compatible variable domains in ExpandElement()"; + return; + } } const Domain index_domain = context->DomainOf(index_ref); + + std::map supports; + if (all_constants && target_ref != index_ref) { + if (!context->IntersectDomainWith( + target_ref, Domain::FromValues( + {reached_values.begin(), reached_values.end()}))) { + VLOG(1) << "Empty domain for the target variable in ExpandElement()"; + return; + } + + const Domain domain = context->DomainOf(target_ref); + if (domain.Size() == 1) { + context->UpdateRuleStats("element: array is constant"); + return; + } + + for (const ClosedInterval& interval : context->DomainOf(target_ref)) { + for (int64 v = interval.start; v <= interval.end; ++v) { + const int lit = context->GetOrCreateVarValueEncoding(target_ref, v); + CHECK(gtl::ContainsKey(reached_values, v)); + supports[v] = + context->working_model->add_constraints()->mutable_bool_or(); + supports[v]->add_literals(NegatedRef(lit)); + } + } + } + const Domain target_domain = context->DomainOf(target_ref); // While this is not stricly needed since all value in the index will be @@ -433,13 +490,16 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) { bool_or->add_literals(index_lit); if (target_ref == index_ref) { - // This adds extra code. But this information is really important, and - // hard to retrieve once lost. + // This adds extra code. But this information is really important, + // and hard to retrieve once lost. context->AddImplyInDomain(index_lit, var, Domain(v)); } else if (target_domain.Size() == 1) { context->AddImplyInDomain(index_lit, var, target_domain); } else if (var_domain.Size() == 1) { context->AddImplyInDomain(index_lit, target_ref, var_domain); + if (all_constants) { + supports[var_domain.Min()]->add_literals(index_lit); + } } else { ConstraintProto* const ct = context->working_model->add_constraints(); ct->add_enforcement_literal(index_lit); @@ -452,7 +512,12 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) { } } } - context->UpdateRuleStats("element: expanded"); + + if (all_constants) { + context->UpdateRuleStats("element: expanded value element"); + } else { + context->UpdateRuleStats("element: expanded"); + } ct->Clear(); } diff --git a/ortools/sat/presolve_util.cc b/ortools/sat/presolve_util.cc new file mode 100644 index 0000000000..0d8a55bd85 --- /dev/null +++ b/ortools/sat/presolve_util.cc @@ -0,0 +1,99 @@ +// Copyright 2010-2018 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/presolve_util.h" + +#include "ortools/base/map_util.h" + +namespace operations_research { +namespace sat { + +void DomainDeductions::AddDeduction(int literal_ref, int var, Domain domain) { + CHECK_GE(var, 0); + const Index index = IndexFromLiteral(literal_ref); + if (index >= something_changed_.size()) { + something_changed_.Resize(index + 1); + enforcement_to_vars_.resize(index.value() + 1); + } + if (var >= tmp_num_occurences_.size()) { + tmp_num_occurences_.resize(var + 1, 0); + } + const auto insert = deductions_.insert({{index, var}, domain}); + if (insert.second) { + // New element. + something_changed_.Set(index); + enforcement_to_vars_[index].push_back(var); + } else { + // Existing element. + const Domain& old_domain = insert.first->second; + if (!old_domain.IsIncludedIn(domain)) { + insert.first->second = domain.IntersectionWith(old_domain); + something_changed_.Set(index); + } + } +} + +std::vector> DomainDeductions::ProcessClause( + absl::Span clause) { + std::vector> result; + + // We only need to process this clause if something changed since last time. + bool abort = true; + for (const int ref : clause) { + const Index index = IndexFromLiteral(ref); + if (index >= something_changed_.size()) return result; + if (something_changed_[index]) { + abort = false; + } + } + if (abort) return result; + + // Count for each variable, how many times it appears in the deductions lists. + std::vector to_process; + std::vector to_clean; + for (const int ref : clause) { + const Index index = IndexFromLiteral(ref); + for (const int var : enforcement_to_vars_[index]) { + if (tmp_num_occurences_[var] == 0) { + to_clean.push_back(var); + } + tmp_num_occurences_[var]++; + if (tmp_num_occurences_[var] == clause.size()) { + to_process.push_back(var); + } + } + } + + // Clear the counts. + for (const int var : to_clean) { + tmp_num_occurences_[var] = 0; + } + + // Compute the domain unions. + std::vector domains(to_process.size()); + for (const int ref : clause) { + const Index index = IndexFromLiteral(ref); + for (int i = 0; i < to_process.size(); ++i) { + domains[i] = domains[i].UnionWith( + gtl::FindOrDieNoPrint(deductions_, {index, to_process[i]})); + } + } + + for (int i = 0; i < to_process.size(); ++i) { + result.push_back({to_process[i], std::move(domains[i])}); + } + return result; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/presolve_util.h b/ortools/sat/presolve_util.h new file mode 100644 index 0000000000..c284603454 --- /dev/null +++ b/ortools/sat/presolve_util.h @@ -0,0 +1,80 @@ +// Copyright 2010-2018 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_PRESOLVE_UTIL_H_ +#define OR_TOOLS_SAT_PRESOLVE_UTIL_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "ortools/base/int_type.h" +#include "ortools/base/int_type_indexed_vector.h" +#include "ortools/base/integral_types.h" +#include "ortools/base/logging.h" +#include "ortools/util/bitset.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { + +// If for each literal of a clause, we can infer a domain on an integer +// variable, then we know that this variable domain is included in the union of +// such infered domains. +// +// This allows to propagate "element" like constraints encoded as enforced +// linear relations, and other more general reasoning. +// +// TODO(user): Also use these "deductions" in the solver directly. This is done +// in good MIP solvers, and we should exploit them more. +class DomainDeductions { + public: + // Adds the fact that enforcement => var \in domain. + void AddDeduction(int literal_ref, int var, Domain domain); + + // Returns list of (var, domain) that were deduced because: + // 1/ We have a domain deduction for var and all literal from the clause + // 2/ So we can take the union of all the deduced domains. + // + // TODO(user): We could probably be even more efficient. We could also + // compute exactly what clauses need to be "waked up" as new deductions are + // added. + std::vector> ProcessClause( + absl::Span clause); + + // Optimization. Any following ProcessClause() will be fast if no more + // deduction touching that clause are added. + void MarkProcessingAsDoneForNow() { + something_changed_.ClearAndResize(something_changed_.size()); + } + + // Returns the total number of "deductions" stored by this class. + int NumDeductions() const { return deductions_.size(); } + + private: + DEFINE_INT_TYPE(Index, int); + Index IndexFromLiteral(int ref) { + return Index(ref >= 0 ? 2 * ref : -2 * ref - 1); + } + + std::vector tmp_num_occurences_; + + SparseBitset something_changed_; + gtl::ITIVector> enforcement_to_vars_; + absl::flat_hash_map, Domain> deductions_; +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_PRESOLVE_UTIL_H_