big bang on CP-SAT: remove element and inverse constraints, expand them
This commit is contained in:
@@ -415,9 +415,66 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
const int size = element.vars_size();
|
||||
if (!context->IntersectDomainWith(index_ref, Domain(0, size - 1))) {
|
||||
VLOG(1) << "Empty domain for the index variable in ExpandElement()";
|
||||
return;
|
||||
}
|
||||
|
||||
bool all_constants = true;
|
||||
std::set<int64> reached_values;
|
||||
std::vector<int64> invalid_indices;
|
||||
const Domain initial_index_domain = context->DomainOf(index_ref);
|
||||
const Domain initial_target_domain = context->DomainOf(target_ref);
|
||||
for (const ClosedInterval& interval : initial_index_domain) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
const int var = element.vars(v);
|
||||
const Domain var_domain = context->DomainOf(var);
|
||||
if (var_domain.IntersectionWith(initial_target_domain).IsEmpty()) {
|
||||
invalid_indices.push_back(v);
|
||||
continue;
|
||||
}
|
||||
if (var_domain.Min() != var_domain.Max()) {
|
||||
all_constants = false;
|
||||
break;
|
||||
}
|
||||
reached_values.insert(var_domain.Min());
|
||||
}
|
||||
}
|
||||
|
||||
if (!invalid_indices.empty()) {
|
||||
if (!context->IntersectDomainWith(
|
||||
index_ref, Domain::FromValues(invalid_indices).Complement())) {
|
||||
VLOG(1) << "No compatible variable domains in ExpandElement()";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const Domain index_domain = context->DomainOf(index_ref);
|
||||
|
||||
std::map<int64, BoolArgumentProto*> supports;
|
||||
if (all_constants && target_ref != index_ref) {
|
||||
if (!context->IntersectDomainWith(
|
||||
target_ref, Domain::FromValues(
|
||||
{reached_values.begin(), reached_values.end()}))) {
|
||||
VLOG(1) << "Empty domain for the target variable in ExpandElement()";
|
||||
return;
|
||||
}
|
||||
|
||||
const Domain domain = context->DomainOf(target_ref);
|
||||
if (domain.Size() == 1) {
|
||||
context->UpdateRuleStats("element: array is constant");
|
||||
return;
|
||||
}
|
||||
|
||||
for (const ClosedInterval& interval : context->DomainOf(target_ref)) {
|
||||
for (int64 v = interval.start; v <= interval.end; ++v) {
|
||||
const int lit = context->GetOrCreateVarValueEncoding(target_ref, v);
|
||||
CHECK(gtl::ContainsKey(reached_values, v));
|
||||
supports[v] =
|
||||
context->working_model->add_constraints()->mutable_bool_or();
|
||||
supports[v]->add_literals(NegatedRef(lit));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Domain target_domain = context->DomainOf(target_ref);
|
||||
|
||||
// While this is not stricly needed since all value in the index will be
|
||||
@@ -433,13 +490,16 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
bool_or->add_literals(index_lit);
|
||||
|
||||
if (target_ref == index_ref) {
|
||||
// This adds extra code. But this information is really important, and
|
||||
// hard to retrieve once lost.
|
||||
// This adds extra code. But this information is really important,
|
||||
// and hard to retrieve once lost.
|
||||
context->AddImplyInDomain(index_lit, var, Domain(v));
|
||||
} else if (target_domain.Size() == 1) {
|
||||
context->AddImplyInDomain(index_lit, var, target_domain);
|
||||
} else if (var_domain.Size() == 1) {
|
||||
context->AddImplyInDomain(index_lit, target_ref, var_domain);
|
||||
if (all_constants) {
|
||||
supports[var_domain.Min()]->add_literals(index_lit);
|
||||
}
|
||||
} else {
|
||||
ConstraintProto* const ct = context->working_model->add_constraints();
|
||||
ct->add_enforcement_literal(index_lit);
|
||||
@@ -452,7 +512,12 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
context->UpdateRuleStats("element: expanded");
|
||||
|
||||
if (all_constants) {
|
||||
context->UpdateRuleStats("element: expanded value element");
|
||||
} else {
|
||||
context->UpdateRuleStats("element: expanded");
|
||||
}
|
||||
ct->Clear();
|
||||
}
|
||||
|
||||
|
||||
99
ortools/sat/presolve_util.cc
Normal file
99
ortools/sat/presolve_util.cc
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright 2010-2018 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/sat/presolve_util.h"
|
||||
|
||||
#include "ortools/base/map_util.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
|
||||
void DomainDeductions::AddDeduction(int literal_ref, int var, Domain domain) {
|
||||
CHECK_GE(var, 0);
|
||||
const Index index = IndexFromLiteral(literal_ref);
|
||||
if (index >= something_changed_.size()) {
|
||||
something_changed_.Resize(index + 1);
|
||||
enforcement_to_vars_.resize(index.value() + 1);
|
||||
}
|
||||
if (var >= tmp_num_occurences_.size()) {
|
||||
tmp_num_occurences_.resize(var + 1, 0);
|
||||
}
|
||||
const auto insert = deductions_.insert({{index, var}, domain});
|
||||
if (insert.second) {
|
||||
// New element.
|
||||
something_changed_.Set(index);
|
||||
enforcement_to_vars_[index].push_back(var);
|
||||
} else {
|
||||
// Existing element.
|
||||
const Domain& old_domain = insert.first->second;
|
||||
if (!old_domain.IsIncludedIn(domain)) {
|
||||
insert.first->second = domain.IntersectionWith(old_domain);
|
||||
something_changed_.Set(index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<int, Domain>> DomainDeductions::ProcessClause(
|
||||
absl::Span<const int> clause) {
|
||||
std::vector<std::pair<int, Domain>> result;
|
||||
|
||||
// We only need to process this clause if something changed since last time.
|
||||
bool abort = true;
|
||||
for (const int ref : clause) {
|
||||
const Index index = IndexFromLiteral(ref);
|
||||
if (index >= something_changed_.size()) return result;
|
||||
if (something_changed_[index]) {
|
||||
abort = false;
|
||||
}
|
||||
}
|
||||
if (abort) return result;
|
||||
|
||||
// Count for each variable, how many times it appears in the deductions lists.
|
||||
std::vector<int> to_process;
|
||||
std::vector<int> to_clean;
|
||||
for (const int ref : clause) {
|
||||
const Index index = IndexFromLiteral(ref);
|
||||
for (const int var : enforcement_to_vars_[index]) {
|
||||
if (tmp_num_occurences_[var] == 0) {
|
||||
to_clean.push_back(var);
|
||||
}
|
||||
tmp_num_occurences_[var]++;
|
||||
if (tmp_num_occurences_[var] == clause.size()) {
|
||||
to_process.push_back(var);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the counts.
|
||||
for (const int var : to_clean) {
|
||||
tmp_num_occurences_[var] = 0;
|
||||
}
|
||||
|
||||
// Compute the domain unions.
|
||||
std::vector<Domain> domains(to_process.size());
|
||||
for (const int ref : clause) {
|
||||
const Index index = IndexFromLiteral(ref);
|
||||
for (int i = 0; i < to_process.size(); ++i) {
|
||||
domains[i] = domains[i].UnionWith(
|
||||
gtl::FindOrDieNoPrint(deductions_, {index, to_process[i]}));
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < to_process.size(); ++i) {
|
||||
result.push_back({to_process[i], std::move(domains[i])});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
80
ortools/sat/presolve_util.h
Normal file
80
ortools/sat/presolve_util.h
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright 2010-2018 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_SAT_PRESOLVE_UTIL_H_
|
||||
#define OR_TOOLS_SAT_PRESOLVE_UTIL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "ortools/base/int_type.h"
|
||||
#include "ortools/base/int_type_indexed_vector.h"
|
||||
#include "ortools/base/integral_types.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/util/bitset.h"
|
||||
#include "ortools/util/sorted_interval_list.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
|
||||
// If for each literal of a clause, we can infer a domain on an integer
|
||||
// variable, then we know that this variable domain is included in the union of
|
||||
// such infered domains.
|
||||
//
|
||||
// This allows to propagate "element" like constraints encoded as enforced
|
||||
// linear relations, and other more general reasoning.
|
||||
//
|
||||
// TODO(user): Also use these "deductions" in the solver directly. This is done
|
||||
// in good MIP solvers, and we should exploit them more.
|
||||
class DomainDeductions {
|
||||
public:
|
||||
// Adds the fact that enforcement => var \in domain.
|
||||
void AddDeduction(int literal_ref, int var, Domain domain);
|
||||
|
||||
// Returns list of (var, domain) that were deduced because:
|
||||
// 1/ We have a domain deduction for var and all literal from the clause
|
||||
// 2/ So we can take the union of all the deduced domains.
|
||||
//
|
||||
// TODO(user): We could probably be even more efficient. We could also
|
||||
// compute exactly what clauses need to be "waked up" as new deductions are
|
||||
// added.
|
||||
std::vector<std::pair<int, Domain>> ProcessClause(
|
||||
absl::Span<const int> clause);
|
||||
|
||||
// Optimization. Any following ProcessClause() will be fast if no more
|
||||
// deduction touching that clause are added.
|
||||
void MarkProcessingAsDoneForNow() {
|
||||
something_changed_.ClearAndResize(something_changed_.size());
|
||||
}
|
||||
|
||||
// Returns the total number of "deductions" stored by this class.
|
||||
int NumDeductions() const { return deductions_.size(); }
|
||||
|
||||
private:
|
||||
DEFINE_INT_TYPE(Index, int);
|
||||
Index IndexFromLiteral(int ref) {
|
||||
return Index(ref >= 0 ? 2 * ref : -2 * ref - 1);
|
||||
}
|
||||
|
||||
std::vector<int> tmp_num_occurences_;
|
||||
|
||||
SparseBitset<Index> something_changed_;
|
||||
gtl::ITIVector<Index, std::vector<int>> enforcement_to_vars_;
|
||||
absl::flat_hash_map<std::pair<Index, int>, Domain> deductions_;
|
||||
};
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
|
||||
#endif // OR_TOOLS_SAT_PRESOLVE_UTIL_H_
|
||||
Reference in New Issue
Block a user