[CP-SAT] morework on lrat
This commit is contained in:
committed by
Corentin Le Molgat
parent
69a94a445e
commit
3b18bdd58b
@@ -298,6 +298,29 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gate_utils",
|
||||
hdrs = ["gate_utils.h"],
|
||||
deps = [
|
||||
":sat_base",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "gate_utils_test",
|
||||
srcs = ["gate_utils_test.cc"],
|
||||
deps = [
|
||||
":gate_utils",
|
||||
":sat_base",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/base:logging",
|
||||
"@abseil-cpp//absl/random",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "cp_model_cc_proto",
|
||||
visibility = ["//visibility:public"],
|
||||
@@ -1585,6 +1608,7 @@ cc_library(
|
||||
hdrs = ["sat_inprocessing.h"],
|
||||
deps = [
|
||||
":clause",
|
||||
":gate_utils",
|
||||
":linear_programming_constraint",
|
||||
":lrat_proof_handler",
|
||||
":model",
|
||||
@@ -1939,7 +1963,9 @@ cc_library(
|
||||
hdrs = ["integer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":clause",
|
||||
":integer_base",
|
||||
":lrat_proof_handler",
|
||||
":model",
|
||||
":sat_base",
|
||||
":sat_parameters_cc_proto",
|
||||
@@ -4279,6 +4305,7 @@ cc_library(
|
||||
":recordio",
|
||||
":sat_base",
|
||||
":synchronization",
|
||||
":util",
|
||||
"//ortools/base:file",
|
||||
"//ortools/base:intops",
|
||||
"//ortools/base:timer",
|
||||
@@ -4291,6 +4318,19 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "lrat_proof_handler_test",
|
||||
srcs = ["lrat_proof_handler_test.cc"],
|
||||
deps = [
|
||||
":lrat_proof_handler",
|
||||
":model",
|
||||
":sat_base",
|
||||
":util",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "lrat_checker_test",
|
||||
srcs = ["lrat_checker_test.cc"],
|
||||
|
||||
@@ -404,7 +404,7 @@ bool AllDifferentConstraint::Propagate() {
|
||||
}
|
||||
}
|
||||
|
||||
return trail_->EnqueueWithStoredReason(x_lit.Negated());
|
||||
return trail_->EnqueueWithStoredReason(kNoClauseId, x_lit.Negated());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,7 +308,7 @@ bool CircuitPropagator::Propagate() {
|
||||
std::vector<Literal>* reason = trail_.GetEmptyVectorToStoreReason();
|
||||
FillReasonForPath(start_node, reason);
|
||||
enforcement_helper_.AddEnforcementReason(enforcement_id_, reason);
|
||||
if (!trail_.EnqueueWithStoredReason(literal.Negated())) {
|
||||
if (!trail_.EnqueueWithStoredReason(kNoClauseId, literal.Negated())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -357,7 +357,8 @@ bool CircuitPropagator::Propagate() {
|
||||
if (extra_reason != kFalseLiteralIndex) {
|
||||
reason->push_back(Literal(extra_reason));
|
||||
}
|
||||
const bool ok = trail_.EnqueueWithStoredReason(literal.Negated());
|
||||
const bool ok =
|
||||
trail_.EnqueueWithStoredReason(kNoClauseId, literal.Negated());
|
||||
if (!ok) return false;
|
||||
continue;
|
||||
}
|
||||
@@ -398,7 +399,7 @@ bool CircuitPropagator::Propagate() {
|
||||
std::vector<Literal>* reason = trail_.GetEmptyVectorToStoreReason();
|
||||
FillReasonForPath(start_node, reason);
|
||||
enforcement_helper_.AddEnforcementReason(enforcement_id_, reason);
|
||||
const bool ok = trail_.EnqueueWithStoredReason(literal);
|
||||
const bool ok = trail_.EnqueueWithStoredReason(kNoClauseId, literal);
|
||||
if (!ok) return false;
|
||||
} else {
|
||||
trail_.EnqueueWithSameReasonAs(literal, variable_with_same_reason);
|
||||
@@ -669,8 +670,8 @@ bool CircuitCoveringPropagator::Propagate() {
|
||||
!trail_->Assignment().LiteralIsFalse(graph_[end][start])) {
|
||||
auto* reason = trail_->GetEmptyVectorToStoreReason();
|
||||
FillFixedPathInReason(start, end, reason);
|
||||
const bool ok =
|
||||
trail_->EnqueueWithStoredReason(graph_[end][start].Negated());
|
||||
const bool ok = trail_->EnqueueWithStoredReason(
|
||||
kNoClauseId, graph_[end][start].Negated());
|
||||
if (!ok) return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,20 +477,7 @@ bool ClauseManager::InprocessingAddUnitClause(ClauseId unit_clause_id,
|
||||
|
||||
bool ClauseManager::InprocessingFixLiteral(
|
||||
Literal true_literal, absl::Span<const ClauseId> clause_ids) {
|
||||
DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
|
||||
if (trail_->Assignment().LiteralIsTrue(true_literal)) return true;
|
||||
|
||||
ClauseId clause_id = kNoClauseId;
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
clause_id = clause_id_generator_->GetNextId();
|
||||
lrat_proof_handler_->AddInferredClause(clause_id, {true_literal},
|
||||
clause_ids);
|
||||
}
|
||||
trail_->EnqueueWithUnitReason(clause_id, true_literal);
|
||||
|
||||
// Even when all clauses are detached, we can propagate the implication
|
||||
// graph and we do that right away.
|
||||
return implication_graph_->Propagate(trail_);
|
||||
return implication_graph_->FixLiteral(true_literal, clause_ids);
|
||||
}
|
||||
|
||||
void ClauseManager::ChangeLbdIfBetter(SatClause* clause, int new_lbd) {
|
||||
@@ -943,21 +930,17 @@ bool BinaryImplicationGraph::AddBinaryClauseInternal(
|
||||
ClauseId id, Literal a, Literal b, bool change_reason,
|
||||
bool delete_non_representative_id) {
|
||||
SCOPED_TIME_STAT(&stats_);
|
||||
DCHECK_GE(a.Index(), 0);
|
||||
DCHECK_GE(b.Index(), 0);
|
||||
|
||||
// Tricky: If this is the first clause, the propagator will be added and
|
||||
// assumed to be in a "propagated" state. This makes sure this is the case.
|
||||
if (no_constraint_ever_added_) propagation_trail_index_ = trail_->Index();
|
||||
no_constraint_ever_added_ = false;
|
||||
|
||||
Literal rep_a = a;
|
||||
Literal rep_b = b;
|
||||
ClauseId rep_id = kNoClauseId;
|
||||
if (is_redundant_[a.Index()]) {
|
||||
rep_a = Literal(representative_of_[a.Index()]);
|
||||
}
|
||||
if (is_redundant_[b.Index()]) {
|
||||
rep_b = Literal(representative_of_[b.Index()]);
|
||||
}
|
||||
const Literal rep_a = RepresentativeOf(a);
|
||||
const Literal rep_b = RepresentativeOf(b);
|
||||
if (rep_a == rep_b.Negated()) return true;
|
||||
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
@@ -1361,15 +1344,50 @@ void BinaryImplicationGraph::Reimply(Trail* trail, int old_trail_index) {
|
||||
// can take advantage of that.
|
||||
void BinaryImplicationGraph::MinimizeConflictFirst(
|
||||
const Trail& trail, std::vector<Literal>* conflict,
|
||||
SparseBitset<BooleanVariable>* marked, std::vector<ClauseId>* clause_ids) {
|
||||
SparseBitset<BooleanVariable>* marked, std::vector<ClauseId>* clause_ids,
|
||||
bool also_use_decisions) {
|
||||
SCOPED_TIME_STAT(&stats_);
|
||||
DCHECK(!conflict->empty());
|
||||
is_marked_.ClearAndResize(LiteralIndex(implications_and_amos_.size()));
|
||||
|
||||
tmp_to_keep_.clear();
|
||||
tmp_to_keep_.push_back(conflict->front().Negated());
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
MarkDescendants</*fill_implied_by=*/true>(conflict->front().Negated());
|
||||
} else {
|
||||
MarkDescendants(conflict->front().Negated());
|
||||
}
|
||||
|
||||
// Because the decision cannot be removed from the conflict, we know they will
|
||||
// stay, so it is okay to see what they propagate in the binary implication
|
||||
// graph. Technically we could do that for any first literal of a decision
|
||||
// level. Improve?
|
||||
std::vector<std::pair<Literal, int>> decisions;
|
||||
if (also_use_decisions) {
|
||||
for (int i = 1; i < conflict->size(); ++i) {
|
||||
const auto& info = trail_->Info((*conflict)[i].Variable());
|
||||
if (info.type == AssignmentType::kSearchDecision) {
|
||||
decisions.push_back({(*conflict)[i].Negated(), info.level});
|
||||
}
|
||||
}
|
||||
absl::c_stable_sort(decisions, [](const std::pair<LiteralIndex, int>& a,
|
||||
const std::pair<LiteralIndex, int>& b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
}
|
||||
|
||||
// Keep marking everything propagated by the decisions, and make sure we
|
||||
// don't remove the one from which we called MarkDescendants().
|
||||
for (const auto [literal, unused_level] : decisions) {
|
||||
if (is_marked_[literal]) continue;
|
||||
tmp_to_keep_.push_back(literal);
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
MarkDescendants</*fill_implied_by=*/true>(literal);
|
||||
} else {
|
||||
MarkDescendants(literal);
|
||||
}
|
||||
}
|
||||
|
||||
for (const LiteralIndex i : is_marked_.PositionsSetAtLeastOnce()) {
|
||||
// TODO(user): if this is false, then we actually have a conflict of size 2.
|
||||
// This can only happen if the binary clause was not propagated properly
|
||||
@@ -1379,6 +1397,9 @@ void BinaryImplicationGraph::MinimizeConflictFirst(
|
||||
marked->Set(Literal(i).Variable());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove all marked literals from the conflict.
|
||||
for (const Literal l : tmp_to_keep_) is_marked_.Clear(l);
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
RemoveRedundantLiterals</*fill_clause_ids=*/true>(conflict, clause_ids);
|
||||
} else {
|
||||
|
||||
@@ -46,15 +46,6 @@
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
|
||||
// A generator for ClauseIds. Not thread-safe.
|
||||
class ClauseIdGenerator {
|
||||
public:
|
||||
ClauseId GetNextId() { return ClauseId(next_id_++); }
|
||||
|
||||
private:
|
||||
ClauseId next_id_ = ClauseId(1);
|
||||
};
|
||||
|
||||
// This is how the SatSolver stores a clause. A clause is just a disjunction of
|
||||
// literals. In many places, we just use vector<literal> to encode one. But in
|
||||
// the critical propagation code, we use this class to remove one memory
|
||||
@@ -709,7 +700,8 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
// details about the different algorithms.
|
||||
void MinimizeConflictFirst(const Trail& trail, std::vector<Literal>* c,
|
||||
SparseBitset<BooleanVariable>* marked,
|
||||
std::vector<ClauseId>* clause_ids);
|
||||
std::vector<ClauseId>* clause_ids,
|
||||
bool also_use_decisions);
|
||||
|
||||
// Appends the IDs of the unit and binary clauses that imply the given
|
||||
// literals to `clause_ids`. Either `MinimizeConflictFirst` or
|
||||
@@ -985,6 +977,12 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
return implications_and_amos_[lit].offsets();
|
||||
}
|
||||
|
||||
// Simple wrapper to not forget to output newly fixed variable to the DRAT
|
||||
// or LRAT proof (with clause_ids as proof) if needed. This will propagate
|
||||
// right away the implications.
|
||||
bool FixLiteral(Literal true_literal,
|
||||
absl::Span<const ClauseId> clause_ids = {});
|
||||
|
||||
private:
|
||||
friend class LratEquivalenceHelper;
|
||||
|
||||
@@ -1005,12 +1003,6 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
clause_id_[{a, b}] = id;
|
||||
}
|
||||
|
||||
// Simple wrapper to not forget to output newly fixed variable to the DRAT
|
||||
// or LRAT proof (with clause_ids as proof) if needed. This will propagate
|
||||
// right away the implications.
|
||||
bool FixLiteral(Literal true_literal,
|
||||
absl::Span<const ClauseId> clause_ids = {});
|
||||
|
||||
// Removes any literal whose negation is marked (except the first one). If
|
||||
// `fill_clause_ids` is true, fills the LRAT proof for this change in
|
||||
// `clause_ids` (this requires an LRAT proof handler to be set, and
|
||||
@@ -1117,12 +1109,14 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
int64_t num_redundant_literals_ = 0;
|
||||
|
||||
// Bitset used by MinimizeClause().
|
||||
//
|
||||
// TODO(user): use the same one as the one used in the classic minimization
|
||||
// because they are already initialized. Moreover they contains more
|
||||
// information.
|
||||
// information?
|
||||
SparseBitset<LiteralIndex> is_marked_;
|
||||
SparseBitset<LiteralIndex> tmp_bitset_;
|
||||
SparseBitset<LiteralIndex> is_simplified_;
|
||||
std::vector<Literal> tmp_to_keep_;
|
||||
|
||||
// Used by AppendImplicationChains() to avoid processing a unit clause several
|
||||
// times.
|
||||
|
||||
@@ -858,9 +858,6 @@ std::vector<SatParameters> GetFullWorkerParameters(
|
||||
ModelHasSchedulingConstraints(cp_model);
|
||||
|
||||
// Our current set of strategies
|
||||
//
|
||||
// TODO(user): Avoid launching two strategies if they are the same,
|
||||
// like if there is no lp, or everything is already linearized at level 1.
|
||||
std::vector<std::string> names;
|
||||
|
||||
// Starts by adding user specified ones.
|
||||
@@ -1021,7 +1018,15 @@ std::vector<SatParameters> GetFullWorkerParameters(
|
||||
|
||||
if (result.size() > num_to_keep) {
|
||||
result.resize(std::max(0, num_to_keep));
|
||||
} else if (!result.empty() && num_to_keep >= 0) {
|
||||
// If we have less parameters, duplicate the first one until we have enough.
|
||||
// This is a bit hacky but easily allow to do experiment with n times the
|
||||
// same subsolver.
|
||||
while (result.size() < num_to_keep) {
|
||||
result.push_back(result[0]);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
199
ortools/sat/gate_utils.h
Normal file
199
ortools/sat/gate_utils.h
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Functions to manipulate a "small" truth table where
|
||||
// f(X0, X1, X2) is true iff bitmask[X0 + (X1 << 1) + (X2 << 2)] is true.
|
||||
#ifndef ORTOOLS_SAT_GATE_UTILS_H_
|
||||
#define ORTOOLS_SAT_GATE_UTILS_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
|
||||
namespace operations_research::sat {
|
||||
|
||||
using SmallBitset = uint32_t;
|
||||
|
||||
// Sort the key and modify the truth table accordingly.
|
||||
//
|
||||
// Note that we don't deal with identical key here, but the function
|
||||
// CanonicalizeFunctionTruthTable() does, and that is sufficient for our use
|
||||
// case.
|
||||
template <typename VarOrLiteral>
|
||||
void CanonicalizeTruthTable(absl::Span<VarOrLiteral> key,
|
||||
SmallBitset& bitmask) {
|
||||
const int num_bits = key.size();
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0);
|
||||
for (int i = 0; i < num_bits; ++i) {
|
||||
for (int j = i + 1; j < num_bits; ++j) {
|
||||
if (key[i] <= key[j]) continue;
|
||||
|
||||
std::swap(key[i], key[j]);
|
||||
|
||||
// We need to swap bit positions i and j in bitmask.
|
||||
SmallBitset new_bitmask = 0;
|
||||
for (int p = 0; p < (1 << num_bits); ++p) {
|
||||
const int value_i = (p >> i) & 1;
|
||||
const int value_j = (p >> j) & 1;
|
||||
int new_p = p;
|
||||
new_p ^= (value_i << i) ^ (value_j << j); // Clear.
|
||||
new_p ^= (value_i << j) ^ (value_j << i); // Swap.
|
||||
new_bitmask |= ((bitmask >> p) & 1) << new_p;
|
||||
}
|
||||
bitmask = new_bitmask;
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0)
|
||||
<< i << " " << j << " " << num_bits;
|
||||
}
|
||||
}
|
||||
CHECK(std::is_sorted(key.begin(), key.end()));
|
||||
}
|
||||
|
||||
// Given a clause, return the truth table corresponding to it.
|
||||
// Namely, a single value should be excluded.
|
||||
inline void FillKeyAndBitmask(absl::Span<const Literal> clause,
|
||||
absl::Span<BooleanVariable> key,
|
||||
SmallBitset& bitmask) {
|
||||
CHECK_EQ(clause.size(), key.size());
|
||||
const int num_bits = clause.size();
|
||||
bitmask = ~SmallBitset(0) >> (32 - (1 << num_bits)); // All allowed;
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0) << num_bits;
|
||||
SmallBitset bit_to_remove = 0;
|
||||
for (int i = 0; i < num_bits; ++i) {
|
||||
key[i] = clause[i].Variable();
|
||||
bit_to_remove |= (clause[i].IsPositive() ? 0 : 1) << i;
|
||||
}
|
||||
bitmask ^= (SmallBitset(1) << bit_to_remove);
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0) << bit_to_remove << " " << num_bits;
|
||||
CanonicalizeTruthTable<BooleanVariable>(key, bitmask);
|
||||
}
|
||||
|
||||
// Returns true iff the truth table encoded in bitmask encode a function
|
||||
// Xi = f(Xj, j != i);
|
||||
template <int num_bits>
|
||||
bool IsFunction(int i, SmallBitset truth_table) {
|
||||
DCHECK_GE(i, 0);
|
||||
DCHECK_LT(i, num_bits);
|
||||
|
||||
// We need to check that there is never two possibilities for Xi.
|
||||
for (int p = 0; p < (1 << num_bits); ++p) {
|
||||
if ((truth_table >> p) & (truth_table >> (p ^ (1 << i))) & 1) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline int AddHoleAtPosition(int i, int bitset) {
|
||||
return (bitset & ((1 << i) - 1)) + ((bitset >> i) << (i + 1));
|
||||
}
|
||||
|
||||
// The function is target = function_values[inputs as bit position].
|
||||
//
|
||||
// TODO(user): This can be optimized with more bit twiddling if needed.
|
||||
inline int CanonicalizeFunctionTruthTable(LiteralIndex& target,
|
||||
absl::Span<LiteralIndex> inputs,
|
||||
int& int_function_values) {
|
||||
// We want to fit on an int.
|
||||
CHECK_LE(inputs.size(), 4);
|
||||
|
||||
// We assume smaller type.
|
||||
SmallBitset function_values = int_function_values;
|
||||
|
||||
const int num_bits = inputs.size();
|
||||
CHECK_LE(num_bits, 4); // Truth table must fit on an int.
|
||||
const SmallBitset all_one = (1 << (1 << num_bits)) - 1;
|
||||
CHECK_EQ(function_values & ~all_one, 0);
|
||||
|
||||
// Make sure target is positive.
|
||||
if (!Literal(target).IsPositive()) {
|
||||
target = Literal(target).Negated();
|
||||
function_values = function_values ^ all_one;
|
||||
CHECK_EQ(function_values >> (1 << num_bits), 0);
|
||||
}
|
||||
|
||||
// Make sure all inputs are positive.
|
||||
for (int i = 0; i < num_bits; ++i) {
|
||||
if (Literal(inputs[i]).IsPositive()) continue;
|
||||
|
||||
inputs[i] = Literal(inputs[i]).NegatedIndex();
|
||||
|
||||
// Position p go to position (p ^ (1 << i)).
|
||||
SmallBitset new_truth_table = 0;
|
||||
const SmallBitset to_xor = 1 << i;
|
||||
for (int p = 0; p < (1 << num_bits); ++p) {
|
||||
new_truth_table |= ((function_values >> p) & 1) << (p ^ to_xor);
|
||||
}
|
||||
function_values = new_truth_table;
|
||||
CHECK_EQ(function_values >> (1 << num_bits), 0);
|
||||
}
|
||||
|
||||
// Sort the inputs now.
|
||||
CanonicalizeTruthTable<LiteralIndex>(inputs, function_values);
|
||||
CHECK_EQ(function_values >> (1 << num_bits), 0);
|
||||
|
||||
// Merge identical variables.
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
for (int j = i + 1; j < inputs.size();) {
|
||||
if (inputs[i] == inputs[j]) {
|
||||
// Lets remove input j.
|
||||
for (int k = j; k + 1 < inputs.size(); ++k) inputs[k] = inputs[k + 1];
|
||||
inputs.remove_suffix(1);
|
||||
|
||||
SmallBitset new_truth_table = 0;
|
||||
for (int p = 0; p < (1 << inputs.size()); ++p) {
|
||||
int extended_p = AddHoleAtPosition(j, p);
|
||||
extended_p |= ((p >> i) & 1) << j; // fill it with bit i.
|
||||
new_truth_table |= ((function_values >> extended_p) & 1) << p;
|
||||
}
|
||||
function_values = new_truth_table;
|
||||
} else {
|
||||
++j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lower arity?
|
||||
// This can happen if the output do not depend on one of the inputs.
|
||||
for (int i = 0; i < inputs.size();) {
|
||||
bool remove = true;
|
||||
for (int p = 0; p < (1 << inputs.size()); ++p) {
|
||||
if (((function_values >> p) & 1) !=
|
||||
((function_values >> (p ^ (1 << i))) & 1)) {
|
||||
remove = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (remove) {
|
||||
// Lets remove input i.
|
||||
for (int k = i; k + 1 < inputs.size(); ++k) inputs[k] = inputs[k + 1];
|
||||
inputs.remove_suffix(1);
|
||||
|
||||
SmallBitset new_truth_table = 0;
|
||||
for (int p = 0; p < (1 << inputs.size()); ++p) {
|
||||
const int extended_p = AddHoleAtPosition(i, p);
|
||||
new_truth_table |= ((function_values >> extended_p) & 1) << p;
|
||||
}
|
||||
function_values = new_truth_table;
|
||||
} else {
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
int_function_values = function_values;
|
||||
return inputs.size();
|
||||
}
|
||||
|
||||
} // namespace operations_research::sat
|
||||
|
||||
#endif // ORTOOLS_SAT_GATE_UTILS_H_
|
||||
147
ortools/sat/gate_utils_test.cc
Normal file
147
ortools/sat/gate_utils_test.cc
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/sat/gate_utils.h"
|
||||
|
||||
#include <array>
|
||||
#include <bitset>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/random/random.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
|
||||
namespace operations_research::sat {
|
||||
namespace {
|
||||
|
||||
TEST(CanonicalizeTruthTableTest, BasicBehavior1) {
|
||||
std::array<int, 3> key = {0, 2, 1};
|
||||
|
||||
// no change here.
|
||||
SmallBitset bitmask = 0b10101010;
|
||||
CanonicalizeTruthTable<int>(absl::MakeSpan(key), bitmask);
|
||||
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b10101010));
|
||||
}
|
||||
|
||||
TEST(CanonicalizeTruthTableTest, BasicBehavior2) {
|
||||
std::array<int, 3> key = {2, 0, 1};
|
||||
SmallBitset bitmask = 0b10101010;
|
||||
CanonicalizeTruthTable<int>(absl::MakeSpan(key), bitmask);
|
||||
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11110000));
|
||||
}
|
||||
|
||||
TEST(CanonicalizeTruthTableTest, BasicBehavior3) {
|
||||
std::array<int, 3> key = {1, 0, 2};
|
||||
SmallBitset bitmask = 0b10101010;
|
||||
CanonicalizeTruthTable<int>(absl::MakeSpan(key), bitmask);
|
||||
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11001100));
|
||||
}
|
||||
|
||||
TEST(FillKeyAndBitmaskTest, BasicBehavior1) {
|
||||
std::array<BooleanVariable, 3> key;
|
||||
SmallBitset bitmask;
|
||||
FillKeyAndBitmask({Literal(+1), Literal(-2), Literal(+3)},
|
||||
absl::MakeSpan(key), bitmask);
|
||||
EXPECT_THAT(key,
|
||||
::testing::ElementsAre(BooleanVariable(0), BooleanVariable(1),
|
||||
BooleanVariable(2)));
|
||||
// The bit number 2 = 0b010 should be off.
|
||||
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11111011));
|
||||
}
|
||||
|
||||
TEST(IsFunctionTest, ConstantValue) {
|
||||
EXPECT_TRUE(IsFunction<3>(0, 0b10101010));
|
||||
EXPECT_FALSE(IsFunction<3>(1, 0b10101010));
|
||||
EXPECT_FALSE(IsFunction<3>(2, 0b10101010));
|
||||
}
|
||||
|
||||
TEST(AddHoleAtPositionTest, BasicTest) {
|
||||
EXPECT_EQ(AddHoleAtPosition(0, 0xFF), 0b111111110);
|
||||
EXPECT_EQ(AddHoleAtPosition(1, 0xFF), 0b111111101);
|
||||
EXPECT_EQ(AddHoleAtPosition(8, 0xFF), 0b011111111);
|
||||
}
|
||||
|
||||
TEST(CanonicalizeFunctionTruthTableTest, RandomTest) {
|
||||
absl::BitGen random;
|
||||
const int num_vars = 8;
|
||||
|
||||
for (int num_test = 0; num_test < 1000; ++num_test) {
|
||||
// Lets generate a random function on k random variables.
|
||||
const int k = absl::Uniform(random, 0, 4);
|
||||
const int table = absl::Uniform<uint64_t>(random, 0, 1 << (1 << k));
|
||||
const Literal output(BooleanVariable(100), absl::Bernoulli(random, 0.5));
|
||||
std::vector<Literal> inputs;
|
||||
for (int i = 0; i < k; ++i) {
|
||||
inputs.push_back(
|
||||
Literal(BooleanVariable(absl::Uniform(random, 0, num_vars)),
|
||||
absl::Bernoulli(random, 0.5)));
|
||||
}
|
||||
|
||||
LiteralIndex new_output_index = output.Index();
|
||||
std::vector<Literal> new_inputs = inputs;
|
||||
std::vector<LiteralIndex> new_inputs_index;
|
||||
for (const Literal lit : new_inputs) {
|
||||
new_inputs_index.push_back(lit.Index());
|
||||
}
|
||||
int new_table = table;
|
||||
const int new_size = CanonicalizeFunctionTruthTable(
|
||||
new_output_index, absl::MakeSpan(new_inputs_index), new_table);
|
||||
new_inputs_index.resize(new_size);
|
||||
new_inputs.clear();
|
||||
for (const LiteralIndex lit_index : new_inputs_index) {
|
||||
new_inputs.push_back(Literal(lit_index));
|
||||
}
|
||||
|
||||
// Log before potential failure.
|
||||
LOG(INFO) << "IN arity=" << k << " " << output << " = f(" << inputs << ") "
|
||||
<< std::bitset<16>(table);
|
||||
LOG(INFO) << "OUT arity=" << new_size << " " << Literal(new_output_index)
|
||||
<< " = f(" << new_inputs << ") " << std::bitset<16>(new_table);
|
||||
|
||||
// Now check that both function always take the same value.
|
||||
for (int m = 0; m < (1 << num_vars); ++m) {
|
||||
int index = 0;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
const Literal lit = inputs[i];
|
||||
int value = (m >> lit.Variable().value()) & 1;
|
||||
if (!lit.IsPositive()) value = 1 - value;
|
||||
index |= value << i;
|
||||
}
|
||||
const int target_value = (table >> index) & 1;
|
||||
|
||||
int new_index = 0;
|
||||
for (int i = 0; i < new_inputs.size(); ++i) {
|
||||
const Literal lit = new_inputs[i];
|
||||
int value = (m >> lit.Variable().value()) & 1;
|
||||
if (!lit.IsPositive()) value = 1 - value;
|
||||
new_index |= value << i;
|
||||
}
|
||||
const int new_target_value = (new_table >> new_index) & 1;
|
||||
|
||||
if (output == Literal(new_output_index)) {
|
||||
ASSERT_EQ(target_value, new_target_value) << index << " " << new_index;
|
||||
} else {
|
||||
ASSERT_EQ(output, Literal(new_output_index).Negated());
|
||||
ASSERT_EQ(target_value, 1 - new_target_value)
|
||||
<< index << " " << new_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::sat
|
||||
@@ -21,6 +21,7 @@
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -34,7 +35,9 @@
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/strong_vector.h"
|
||||
#include "ortools/sat/clause.h"
|
||||
#include "ortools/sat/integer_base.h"
|
||||
#include "ortools/sat/lrat_proof_handler.h"
|
||||
#include "ortools/sat/model.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
#include "ortools/sat/sat_parameters.pb.h"
|
||||
@@ -105,6 +108,8 @@ class IntegerEncoder {
|
||||
explicit IntegerEncoder(Model* model)
|
||||
: sat_solver_(model->GetOrCreate<SatSolver>()),
|
||||
trail_(model->GetOrCreate<Trail>()),
|
||||
clause_id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
|
||||
lrat_proof_handler_(model->Mutable<LratProofHandler>()),
|
||||
delayed_to_fix_(model->GetOrCreate<DelayedRootLevelDeduction>()),
|
||||
domains_(*model->GetOrCreate<IntegerDomains>()),
|
||||
num_created_variables_(0) {}
|
||||
@@ -291,9 +296,15 @@ class IntegerEncoder {
|
||||
Literal(sat_solver_->NewBooleanVariable(), true);
|
||||
literal_index_true_ = literal_true.Index();
|
||||
|
||||
// This might return false if we are already UNSAT.
|
||||
// TODO(user): Make sure we abort right away on unsat!
|
||||
(void)sat_solver_->AddUnitClause(literal_true);
|
||||
ClauseId clause_id = kNoClauseId;
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
clause_id = clause_id_generator_->GetNextId();
|
||||
// We cannot prove `literal_true` by unit propagation, but we can with a
|
||||
// RAT inference (trivial here since there are no clauses containing the
|
||||
// negation of the pivot `literal_true`).
|
||||
lrat_proof_handler_->AddInferredClause(clause_id, {literal_true}, {});
|
||||
}
|
||||
trail_->EnqueueWithUnitReason(clause_id, literal_true);
|
||||
}
|
||||
return Literal(literal_index_true_);
|
||||
}
|
||||
@@ -328,6 +339,8 @@ class IntegerEncoder {
|
||||
|
||||
SatSolver* sat_solver_;
|
||||
Trail* trail_;
|
||||
ClauseIdGenerator* clause_id_generator_;
|
||||
LratProofHandler* lrat_proof_handler_;
|
||||
DelayedRootLevelDeduction* delayed_to_fix_;
|
||||
const IntegerDomains& domains_;
|
||||
|
||||
@@ -339,7 +352,7 @@ class IntegerEncoder {
|
||||
// corresponding to the same variable).
|
||||
//
|
||||
// Note that we only keep this for positive variable.
|
||||
// The one for the negation can be infered by it.
|
||||
// The one for the negation can be inferred by it.
|
||||
//
|
||||
// Like x >= 1 x >= 4 x >= 5
|
||||
// Correspond to x <= 0 x <= 3 x <= 4
|
||||
|
||||
@@ -512,8 +512,8 @@ void IntegerConflictResolution::ComputeFirstUIPConflict(
|
||||
}
|
||||
|
||||
if (uip_found) {
|
||||
if (params_.binary_minimization_algorithm() ==
|
||||
SatParameters::BINARY_MINIMIZATION_FIRST) {
|
||||
if (params_.binary_minimization_algorithm() !=
|
||||
SatParameters::NO_BINARY_MINIMIZATION) {
|
||||
if (conflict->empty()) {
|
||||
// This one will always stay in the conflict, even after
|
||||
// minimization. So we can use it to minimize the conflict and avoid
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ortools/sat/lrat_proof_handler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <bitset>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
@@ -23,6 +24,7 @@
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
@@ -40,6 +42,7 @@
|
||||
#include "ortools/sat/recordio.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
#include "ortools/sat/synchronization.h"
|
||||
#include "ortools/sat/util.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
ABSL_FLAG(std::string, cp_model_drat_output, ".\\drat.txt",
|
||||
@@ -423,6 +426,7 @@ std::unique_ptr<LratProofHandler> LratProofHandler::MaybeCreate(Model* model) {
|
||||
|
||||
LratProofHandler::LratProofHandler(Model* model)
|
||||
: id_(model->GetOrCreate<SharedLratProofStatus>()->NewSubSolverId()),
|
||||
id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
|
||||
proof_status_(model->GetOrCreate<SharedLratProofStatus>()) {
|
||||
const SatParameters& params = *model->GetOrCreate<SatParameters>();
|
||||
if (params.check_lrat_proof()) {
|
||||
@@ -448,7 +452,7 @@ LratProofHandler::LratProofHandler(Model* model)
|
||||
|
||||
bool LratProofHandler::AddProblemClause(ClauseId id,
|
||||
absl::Span<const Literal> clause) {
|
||||
VLOG(1) << "AddProblemClause: id=" << id
|
||||
VLOG(2) << "AddProblemClause: id=" << id
|
||||
<< " literals=" << absl::StrJoin(clause, ",");
|
||||
if (all_problem_clauses_loaded_ && debug_crash_on_error_) {
|
||||
LOG(FATAL) << "LRAT error: problem clauses must not be added after "
|
||||
@@ -481,7 +485,7 @@ bool LratProofHandler::AddInferredClause(
|
||||
ClauseId id, absl::Span<const Literal> clause,
|
||||
absl::Span<const ClauseId> unit_ids,
|
||||
absl::Span<const LratChecker::RatIds> rat) {
|
||||
VLOG(1) << "AddInferredClause: id=" << id
|
||||
VLOG(2) << "AddInferredClause: id=" << id
|
||||
<< " literals=" << absl::StrJoin(clause, ",")
|
||||
<< " unit_ids=" << absl::StrJoin(unit_ids, ",") << " rat={"
|
||||
<< absl::StrJoin(rat, " ") << "}";
|
||||
@@ -508,7 +512,7 @@ bool LratProofHandler::AddInferredClause(
|
||||
|
||||
bool LratProofHandler::AddImportedClause(ClauseId id,
|
||||
absl::Span<const Literal> clause) {
|
||||
VLOG(1) << "AddImportedClause: id=" << id
|
||||
VLOG(2) << "AddImportedClause: id=" << id
|
||||
<< " literals=" << absl::StrJoin(clause, ",");
|
||||
if (lrat_checker_ != nullptr &&
|
||||
!lrat_checker_->AddProblemClause(id, clause)) {
|
||||
@@ -526,7 +530,7 @@ bool LratProofHandler::AddImportedClause(ClauseId id,
|
||||
|
||||
bool LratProofHandler::AddAssumedClause(ClauseId id,
|
||||
absl::Span<const Literal> clause) {
|
||||
VLOG(1) << "AddAssumedClause: id=" << id
|
||||
VLOG(2) << "AddAssumedClause: id=" << id
|
||||
<< " literals=" << absl::StrJoin(clause, ",");
|
||||
if (debug_crash_on_error_) {
|
||||
LOG(FATAL) << "LRAT error: assumed clauses are not supposed to happen";
|
||||
@@ -571,7 +575,7 @@ void LratProofHandler::DeleteClause(ClauseId id,
|
||||
delete_pinned_clause_ = true;
|
||||
return;
|
||||
}
|
||||
VLOG(1) << "DeleteClause: id=" << id
|
||||
VLOG(2) << "DeleteClause: id=" << id
|
||||
<< " literals=" << absl::StrJoin(clause, ",");
|
||||
if (lrat_checker_ != nullptr) {
|
||||
lrat_checker_->DeleteClauses({id});
|
||||
@@ -638,5 +642,177 @@ void LratProofHandler::Close(bool model_is_unsat) {
|
||||
}
|
||||
}
|
||||
|
||||
bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
|
||||
ClauseId new_id, absl::Span<const Literal> new_clause,
|
||||
absl::Span<const ClauseId> ids_for_proof,
|
||||
const CompactVectorVector<int, Literal>& clauses_for_proof) {
|
||||
CHECK_EQ(ids_for_proof.size(), clauses_for_proof.size());
|
||||
CHECK(!clauses_for_proof.empty());
|
||||
|
||||
// First we count the number of variables appearing and have a separate dense
|
||||
// index for them. The first new_clause.size() dense index are exactly the
|
||||
// literal of the new_clause.
|
||||
absl::flat_hash_map<BooleanVariable, int> to_dense_index;
|
||||
std::vector<Literal> dense_index_to_literal;
|
||||
dense_index_to_literal.assign(new_clause.begin(), new_clause.end());
|
||||
for (const Literal lit : new_clause) {
|
||||
const auto [it, inserted] =
|
||||
to_dense_index.insert({lit.Variable(), to_dense_index.size()});
|
||||
if (!inserted) {
|
||||
VLOG(2) << "Duplicate variable in new_clause! " << new_clause;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Then any new BooleanVariable appearing get the next dense index.
|
||||
std::vector<Literal> relevant_literals;
|
||||
for (int i = 0; i < clauses_for_proof.size(); ++i) {
|
||||
for (const Literal lit : clauses_for_proof[i]) {
|
||||
const auto [it, inserted] =
|
||||
to_dense_index.insert({lit.Variable(), to_dense_index.size()});
|
||||
if (inserted) {
|
||||
relevant_literals.push_back(lit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Too many variables.
|
||||
//
|
||||
// TODO(user): The limit can be increased a bit if needed.
|
||||
if (to_dense_index.size() > 6) {
|
||||
VLOG(2) << "Too many variables";
|
||||
return false;
|
||||
}
|
||||
|
||||
// For the proof we will need all clauses of the form
|
||||
// {new_clause, l0, ..., lk} for all k in [0, n) and
|
||||
// li = relevant_literals[i] OR relevant_literals[i].Negated().
|
||||
//
|
||||
// That give us 2^(n + 1) intermediate clauses.
|
||||
// Their ids will be stored in (1 << k) + binary_encoding_of_the_li.
|
||||
const int n = to_dense_index.size() - new_clause.size();
|
||||
CHECK_EQ(n, relevant_literals.size());
|
||||
const int num_intermediates = 1 << (n + 1);
|
||||
std::vector<ClauseId> ids(num_intermediates, kNoClauseId);
|
||||
|
||||
if (n == 0) {
|
||||
VLOG(2) << "Nothing to prove! An existing clause is included inside";
|
||||
return false;
|
||||
}
|
||||
|
||||
VLOG(2) << "Starting proof n= " << n << " " << num_intermediates;
|
||||
|
||||
// Any initial clause can be used to prove all the intermediates that contains
|
||||
// it. Note that this code supports duplicate literals in the clauses.
|
||||
for (int i = 0; i < clauses_for_proof.size(); ++i) {
|
||||
bool skip = false;
|
||||
int base_index = 0;
|
||||
int mask = 0;
|
||||
int k = 0;
|
||||
for (const Literal lit : clauses_for_proof[i]) {
|
||||
const int dense_index = to_dense_index[lit.Variable()];
|
||||
if (dense_index < new_clause.size()) {
|
||||
// Check that the literal is the same as in the new_clause, if
|
||||
// not, this clause will not be needed for the proof.
|
||||
if (lit != new_clause[dense_index]) {
|
||||
skip = true;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
k = std::max(k, dense_index);
|
||||
mask |= 1 << dense_index;
|
||||
if (lit == relevant_literals[dense_index - new_clause.size()]) {
|
||||
base_index |= 1 << dense_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (skip) continue;
|
||||
|
||||
mask >>= new_clause.size();
|
||||
base_index >>= new_clause.size();
|
||||
k = k + 1 - new_clause.size();
|
||||
|
||||
VLOG(2) << k << " " << std::bitset<8>(mask) << " "
|
||||
<< std::bitset<8>(base_index);
|
||||
|
||||
// TODO(user): we could be faster here if it become needed.
|
||||
for (int m = 0; m < (1 << n); ++m) {
|
||||
if ((m & mask) != base_index) continue; // not included.
|
||||
const int index = m | base_index;
|
||||
for (int j = k; j <= n; ++j) {
|
||||
if (index >> j == 0) {
|
||||
VLOG(2) << "Included in " << j << " "
|
||||
<< std::bitset<8>((1 << j) | index);
|
||||
ids[(1 << j) | index] = ids_for_proof[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We can prove the others by decreasing k.
|
||||
std::vector<Literal> tmp_clause;
|
||||
tmp_clause.assign(new_clause.begin(), new_clause.end());
|
||||
std::vector<bool> id_need_deletion(num_intermediates, false);
|
||||
for (int k = n; --k >= 0;) {
|
||||
for (int m = 0; m < (1 << k); ++m) {
|
||||
const int index = (1 << k) | m;
|
||||
if (ids[index] != kNoClauseId) continue; // Already proven.
|
||||
|
||||
// Generate the tmp_clause.
|
||||
tmp_clause.resize(new_clause.size());
|
||||
for (int i = 0; i < k; ++i) {
|
||||
tmp_clause.push_back(relevant_literals[i]);
|
||||
if (((index >> i) & 1) == 0) {
|
||||
tmp_clause.back() = tmp_clause.back().Negated();
|
||||
}
|
||||
}
|
||||
|
||||
// Prove it from the two clauses at k + 1.
|
||||
const int higher1 = index ^ ((0b11) << k);
|
||||
const int higher2 = index ^ ((0b10) << k);
|
||||
const ClauseId id1 = ids[higher1];
|
||||
const ClauseId id2 = ids[higher2];
|
||||
if (id1 == kNoClauseId || id2 == kNoClauseId) {
|
||||
VLOG(2) << "missing higher level clauses in the resolution."
|
||||
<< " index: " << std::bitset<8>(index)
|
||||
<< " higher1: " << std::bitset<8>(higher1)
|
||||
<< " higher2: " << std::bitset<8>(higher2);
|
||||
return false;
|
||||
}
|
||||
|
||||
ids[index] = k == 0 ? new_id : id_generator_->GetNextId();
|
||||
if (k != 0) {
|
||||
VLOG(2) << "temporary !! " << ids[index] << " " << tmp_clause;
|
||||
id_need_deletion[index] = true; // temporary.
|
||||
}
|
||||
if (!AddInferredClause(ids[index], tmp_clause, {id1, id2})) {
|
||||
VLOG(2) << "Failed resolution step";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (k == 0) {
|
||||
DCHECK_EQ(new_clause, tmp_clause);
|
||||
VLOG(2) << "Proven " << new_clause << "!";
|
||||
}
|
||||
|
||||
// Lets delete the ids if they were temporary.
|
||||
if (id_need_deletion[higher1]) {
|
||||
tmp_clause.push_back(relevant_literals[k].Negated());
|
||||
VLOG(2) << "deleting: " << id1 << " " << tmp_clause;
|
||||
DeleteClause(id1, tmp_clause);
|
||||
tmp_clause.pop_back();
|
||||
}
|
||||
if (id_need_deletion[higher2]) {
|
||||
tmp_clause.push_back(relevant_literals[k]);
|
||||
VLOG(2) << "deleting: " << id2 << " " << tmp_clause;
|
||||
DeleteClause(id2, tmp_clause);
|
||||
tmp_clause.pop_back();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
#include "ortools/sat/recordio.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
#include "ortools/sat/synchronization.h"
|
||||
#include "ortools/sat/util.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
@@ -149,6 +150,19 @@ class LratProofHandler {
|
||||
absl::Span<const ClauseId> unit_ids,
|
||||
absl::Span<const LratChecker::RatIds> rat = {});
|
||||
|
||||
// This assumes that the 'new_clause' to prove and all the ones needed for the
|
||||
// proof only touch a small number of variables (<= 6). It will then prove the
|
||||
// new clause by enumerating all possibilities and producing the relevant
|
||||
// intermediate LRAT RUP steps.
|
||||
//
|
||||
// The last two arguments must have the same size and are in one to one
|
||||
// correspondence. Note that we might not need all the given clauses in the
|
||||
// proof.
|
||||
bool AddAndProveInferredClauseByEnumeration(
|
||||
ClauseId new_id, absl::Span<const Literal> new_clause,
|
||||
absl::Span<const ClauseId> ids_for_proof,
|
||||
const CompactVectorVector<int, Literal>& clauses_for_proof);
|
||||
|
||||
// Adds a clause which was inferred by another worker. Returns true if
|
||||
// successful (the operation can fail if LRAT checks are enabled, and the ID
|
||||
// is already used by another clause).
|
||||
@@ -190,6 +204,7 @@ class LratProofHandler {
|
||||
bool LratError() const;
|
||||
|
||||
const int id_;
|
||||
ClauseIdGenerator* id_generator_;
|
||||
SharedLratProofStatus* proof_status_;
|
||||
std::unique_ptr<LratChecker> lrat_checker_;
|
||||
std::unique_ptr<LratWriter> lrat_writer_;
|
||||
|
||||
76
ortools/sat/lrat_proof_handler_test.cc
Normal file
76
ortools/sat/lrat_proof_handler_test.cc
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/sat/lrat_proof_handler.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/sat/model.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
#include "ortools/sat/util.h"
|
||||
|
||||
namespace operations_research::sat {
|
||||
namespace {
|
||||
|
||||
TEST(AddAndProveInferredClauseByEnumerationTest, XorEquivalence) {
|
||||
// We assume a = XOR(c,d) and b = XOR(c,d) and we want to prove a => b
|
||||
const Literal a(+1);
|
||||
const Literal b(+2);
|
||||
const Literal c(+3);
|
||||
const Literal d(+4);
|
||||
|
||||
// Encode the two XOR gates.
|
||||
CompactVectorVector<int, Literal> clauses;
|
||||
for (const Literal x : {a, b}) {
|
||||
clauses.Add({c, d, x.Negated()});
|
||||
clauses.Add({c.Negated(), d, x});
|
||||
clauses.Add({c, d.Negated(), x});
|
||||
clauses.Add({c.Negated(), d.Negated(), x.Negated()});
|
||||
}
|
||||
|
||||
// Create the lrat checker.
|
||||
Model model;
|
||||
auto* params = model.GetOrCreate<SatParameters>();
|
||||
params->set_check_lrat_proof(true);
|
||||
params->set_check_drat_proof(true);
|
||||
std::unique_ptr<LratProofHandler> lrat =
|
||||
LratProofHandler::MaybeCreate(&model);
|
||||
|
||||
// Lets create ids for all these clauses.
|
||||
auto* id_generator = model.GetOrCreate<ClauseIdGenerator>();
|
||||
std::vector<ClauseId> clause_ids;
|
||||
for (int i = 0; i < clauses.size(); ++i) {
|
||||
const ClauseId id = id_generator->GetNextId();
|
||||
clause_ids.push_back(id);
|
||||
lrat->AddProblemClause(id, clauses[i]);
|
||||
}
|
||||
lrat->EndProblemClauses();
|
||||
|
||||
// This should be enough to prove equivalence.
|
||||
{
|
||||
std::vector<Literal> to_prove = {b.Negated(), a};
|
||||
EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration(
|
||||
id_generator->GetNextId(), to_prove, clause_ids, clauses));
|
||||
}
|
||||
{
|
||||
std::vector<Literal> to_prove = {a.Negated(), b};
|
||||
EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration(
|
||||
id_generator->GetNextId(), to_prove, clause_ids, clauses));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::sat
|
||||
@@ -242,9 +242,17 @@ class AssignmentView {
|
||||
|
||||
// The ID of a clause.
|
||||
DEFINE_STRONG_INT_TYPE(ClauseId, int64_t);
|
||||
|
||||
constexpr ClauseId kNoClauseId(0);
|
||||
|
||||
// A generator for ClauseIds. Not thread-safe.
|
||||
class ClauseIdGenerator {
|
||||
public:
|
||||
ClauseId GetNextId() { return ClauseId(next_id_++); }
|
||||
|
||||
private:
|
||||
ClauseId next_id_ = ClauseId(1);
|
||||
};
|
||||
|
||||
// Forward declaration.
|
||||
class SatClause;
|
||||
class SatPropagator;
|
||||
@@ -346,12 +354,12 @@ class Trail {
|
||||
EnqueueHelper(
|
||||
Literal* trail_ptr, AssignmentInfo* current_info,
|
||||
AssignmentInfo* info_ptr, VariablesAssignment* assignment,
|
||||
util_intops::StrongVector<BooleanVariable, ClauseId>* unit_clause_id)
|
||||
util_intops::StrongVector<BooleanVariable, ClauseId>* clause_ids)
|
||||
: trail_ptr_(trail_ptr),
|
||||
current_info_(current_info),
|
||||
info_ptr_(info_ptr),
|
||||
bitset_(assignment->GetBitsetView()),
|
||||
unit_clause_id_(unit_clause_id) {}
|
||||
clause_ids_(clause_ids) {}
|
||||
|
||||
void EnqueueAtLevel(Literal true_literal, int level) {
|
||||
bitset_.Set(true_literal);
|
||||
@@ -364,10 +372,10 @@ class Trail {
|
||||
void EnqueueWithUnitReason(Literal true_literal, ClauseId clause_id) {
|
||||
DCHECK_NE(clause_id, kNoClauseId);
|
||||
const BooleanVariable var = true_literal.Variable();
|
||||
if (var.value() >= unit_clause_id_->size()) {
|
||||
unit_clause_id_->resize(var.value() + 1, kNoClauseId);
|
||||
if (var.value() >= clause_ids_->size()) {
|
||||
clause_ids_->resize(var.value() + 1, kNoClauseId);
|
||||
}
|
||||
(*unit_clause_id_)[var] = clause_id;
|
||||
(*clause_ids_)[var] = clause_id;
|
||||
|
||||
bitset_.Set(true_literal);
|
||||
AssignmentInfo* info = info_ptr_ + true_literal.Variable().value();
|
||||
@@ -389,12 +397,12 @@ class Trail {
|
||||
AssignmentInfo* current_info_;
|
||||
AssignmentInfo* info_ptr_;
|
||||
Bitset64<LiteralIndex>::View bitset_;
|
||||
util_intops::StrongVector<BooleanVariable, ClauseId>* unit_clause_id_;
|
||||
util_intops::StrongVector<BooleanVariable, ClauseId>* clause_ids_;
|
||||
};
|
||||
EnqueueHelper GetEnqueueHelper(int propagator_id) {
|
||||
current_info_.type = propagator_id;
|
||||
return EnqueueHelper(trail_.data(), ¤t_info_, info_.data(),
|
||||
&assignment_, &unit_clause_id_);
|
||||
&assignment_, &clause_ids_);
|
||||
}
|
||||
|
||||
// Specific Enqueue() for search decisions.
|
||||
@@ -435,13 +443,7 @@ class Trail {
|
||||
|
||||
// Specific Enqueue() version for unit clauses.
|
||||
void EnqueueWithUnitReason(ClauseId clause_id, Literal true_literal) {
|
||||
if (clause_id != kNoClauseId) {
|
||||
const BooleanVariable var = true_literal.Variable();
|
||||
if (var.value() >= unit_clause_id_.size()) {
|
||||
unit_clause_id_.resize(var.value() + 1, kNoClauseId);
|
||||
}
|
||||
unit_clause_id_[var] = clause_id;
|
||||
}
|
||||
MaybeSetClauseId(true_literal, clause_id);
|
||||
EnqueueAtLevel(true_literal, AssignmentType::kUnitReason, 0);
|
||||
}
|
||||
|
||||
@@ -459,18 +461,21 @@ class Trail {
|
||||
|
||||
// Enqueues the given literal using the current content of
|
||||
// GetEmptyVectorToStoreReason() as the reason. This API is a bit more
|
||||
// leanient and does not require the literal to be unassigned. If it is
|
||||
// lenient and does not require the literal to be unassigned. If it is
|
||||
// already assigned to false, then MutableConflict() will be set appropriately
|
||||
// and this will return false otherwise this will enqueue the literal and
|
||||
// returns true.
|
||||
ABSL_MUST_USE_RESULT bool EnqueueWithStoredReason(Literal true_literal) {
|
||||
ABSL_MUST_USE_RESULT bool EnqueueWithStoredReason(ClauseId clause_id,
|
||||
Literal true_literal) {
|
||||
if (assignment_.LiteralIsTrue(true_literal)) return true;
|
||||
if (assignment_.LiteralIsFalse(true_literal)) {
|
||||
*MutableConflict() = reasons_repository_[Index()];
|
||||
MutableConflict()->push_back(true_literal);
|
||||
failing_clause_id_ = clause_id;
|
||||
return false;
|
||||
}
|
||||
|
||||
MaybeSetClauseId(true_literal, clause_id);
|
||||
Enqueue(true_literal, AssignmentType::kCachedReason);
|
||||
const BooleanVariable var = true_literal.Variable();
|
||||
reasons_[var] = reasons_repository_[info_[var].trail_index];
|
||||
@@ -516,8 +521,17 @@ class Trail {
|
||||
ClauseId GetUnitClauseId(BooleanVariable var) const {
|
||||
DCHECK(AssignmentType(var) == AssignmentType::kUnitReason);
|
||||
DCHECK_EQ(Info(var).level, 0);
|
||||
if (var.value() >= unit_clause_id_.size()) return kNoClauseId;
|
||||
return unit_clause_id_[var];
|
||||
if (var.value() >= clause_ids_.size()) return kNoClauseId;
|
||||
return clause_ids_[var];
|
||||
}
|
||||
|
||||
// Returns the ID of the clause which is the reason why the given variable was
|
||||
// enqueued, or kNoClauseId if there is none. The variable must have been
|
||||
// enqueued with EnqueueWithStoredReason().
|
||||
ClauseId GetStoredReasonClauseId(BooleanVariable var) const {
|
||||
DCHECK(AssignmentType(var) == AssignmentType::kCachedReason);
|
||||
if (var.value() >= clause_ids_.size()) return kNoClauseId;
|
||||
return clause_ids_[var];
|
||||
}
|
||||
|
||||
// If a variable was propagated with EnqueueWithSameReasonAs(), returns its
|
||||
@@ -583,6 +597,7 @@ class Trail {
|
||||
std::vector<Literal>* MutableConflict() {
|
||||
++conflict_timestamp_;
|
||||
failing_sat_clause_ = nullptr;
|
||||
failing_clause_id_ = kNoClauseId;
|
||||
return &conflict_;
|
||||
}
|
||||
|
||||
@@ -600,9 +615,16 @@ class Trail {
|
||||
// Specific SatClause interface so we can update the conflict clause activity.
|
||||
// Note that MutableConflict() automatically sets this to nullptr, so we can
|
||||
// know whether or not the last conflict was caused by a clause.
|
||||
void SetFailingSatClause(SatClause* clause) { failing_sat_clause_ = clause; }
|
||||
void SetFailingSatClause(SatClause* clause) {
|
||||
failing_sat_clause_ = clause;
|
||||
failing_clause_id_ = kNoClauseId;
|
||||
}
|
||||
SatClause* FailingSatClause() const { return failing_sat_clause_; }
|
||||
|
||||
// Returns the LRAT ID of the failing clause. This ID is only set if a
|
||||
// conflict is detected in EnqueueWithStoredReason().
|
||||
ClauseId FailingClauseId() const { return failing_clause_id_; }
|
||||
|
||||
// Getters.
|
||||
int NumVariables() const { return trail_.size(); }
|
||||
int64_t NumberOfEnqueues() const { return num_untrailed_enqueues_ + Index(); }
|
||||
@@ -664,6 +686,16 @@ class Trail {
|
||||
private:
|
||||
ConflictResolutionFunction resolution_;
|
||||
|
||||
void MaybeSetClauseId(Literal true_literal, ClauseId clause_id) {
|
||||
if (clause_id != kNoClauseId) {
|
||||
const BooleanVariable var = true_literal.Variable();
|
||||
if (var.value() >= clause_ids_.size()) {
|
||||
clause_ids_.resize(var.value() + 1, kNoClauseId);
|
||||
}
|
||||
clause_ids_[var] = clause_id;
|
||||
}
|
||||
}
|
||||
|
||||
// Finds all literals between the current trail index and the given one
|
||||
// assigned at the current level or lower, and re-enqueues them with the same
|
||||
// reason.
|
||||
@@ -678,9 +710,11 @@ class Trail {
|
||||
int64_t conflict_timestamp_ = 0;
|
||||
std::vector<Literal> conflict_;
|
||||
util_intops::StrongVector<BooleanVariable, AssignmentInfo> info_;
|
||||
// The ID of unit clauses (literals enqueued at level 0 with a kUnitReason).
|
||||
util_intops::StrongVector<BooleanVariable, ClauseId> unit_clause_id_;
|
||||
// The ID of unit clauses (literals enqueued at level 0 with a kUnitReason)
|
||||
// and of reason clauses for literals enqueued with a stored reason.
|
||||
util_intops::StrongVector<BooleanVariable, ClauseId> clause_ids_;
|
||||
SatClause* failing_sat_clause_;
|
||||
ClauseId failing_clause_id_;
|
||||
|
||||
// Data used by EnqueueWithSameReasonAs().
|
||||
util_intops::StrongVector<BooleanVariable, BooleanVariable>
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ortools/sat/sat_inprocessing.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
@@ -37,6 +38,7 @@
|
||||
#include "ortools/base/timer.h"
|
||||
#include "ortools/graph/connected_components.h"
|
||||
#include "ortools/sat/clause.h"
|
||||
#include "ortools/sat/gate_utils.h"
|
||||
#include "ortools/sat/linear_programming_constraint.h"
|
||||
#include "ortools/sat/lrat_proof_handler.h"
|
||||
#include "ortools/sat/probing.h"
|
||||
@@ -1916,14 +1918,46 @@ GateCongruenceClosure::~GateCongruenceClosure() {
|
||||
});
|
||||
}
|
||||
|
||||
template <int arity>
|
||||
void GateCongruenceClosure::AddToTruthTable(
|
||||
absl::Span<const Literal> clause,
|
||||
absl::flat_hash_map<std::array<BooleanVariable, arity>, SmallBitset>&
|
||||
data) {
|
||||
CHECK_EQ(clause.size(), arity);
|
||||
std::array<BooleanVariable, arity> key;
|
||||
SmallBitset bitmask;
|
||||
FillKeyAndBitmask(clause, absl::MakeSpan(key), bitmask);
|
||||
for (const BooleanVariable var : key) {
|
||||
CHECK(!implication_graph_->IsRemoved(Literal(var, true)));
|
||||
}
|
||||
auto [it, inserted] = data.insert({key, bitmask});
|
||||
if (!inserted) {
|
||||
const SmallBitset old = it->second;
|
||||
it->second &= bitmask; // Remove one value.
|
||||
if (old != it->second) {
|
||||
// TODO(user): keep id for proof.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that this is the "hot" part of the algo, once we have the and gates,
|
||||
// the congruence closure should be quite fast.
|
||||
void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) {
|
||||
void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
|
||||
PresolveTimer& timer) {
|
||||
truth_table3_.clear();
|
||||
truth_table4_.clear();
|
||||
|
||||
std::vector<Literal> candidates;
|
||||
for (SatClause* clause : clause_manager_->AllClausesInCreationOrder()) {
|
||||
if (timer.WorkLimitIsReached()) break;
|
||||
if (clause->size() == 0) continue;
|
||||
|
||||
if (clause->size() == 3) {
|
||||
AddToTruthTable<3>(clause->AsSpan(), truth_table3_);
|
||||
} else if (clause->size() == 4) {
|
||||
AddToTruthTable<4>(clause->AsSpan(), truth_table4_);
|
||||
}
|
||||
|
||||
// Used for an optimization below.
|
||||
int min_num_implications = std::numeric_limits<int>::max();
|
||||
Literal lit_with_less_implications;
|
||||
@@ -2030,6 +2064,8 @@ void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) {
|
||||
// Add the detected gate (its inputs are the negation of each clause
|
||||
// literal other than the target).
|
||||
gates_target_.push_back(target.Index());
|
||||
gates_type_.push_back(kAndGateType);
|
||||
|
||||
const int index = gates_inputs_.Add({});
|
||||
for (const Literal l : clause->AsSpan()) {
|
||||
if (l == target) continue;
|
||||
@@ -2048,6 +2084,60 @@ void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) {
|
||||
// a single base clause of size n will correspond to n and_gates !
|
||||
}
|
||||
}
|
||||
|
||||
timer.AddCounter("and_gates", gates_inputs_.size());
|
||||
}
|
||||
|
||||
template <int arity>
|
||||
void GateCongruenceClosure::ExtractShortGates(
|
||||
PresolveTimer& timer,
|
||||
const absl::flat_hash_map<std::array<BooleanVariable, arity>, SmallBitset>&
|
||||
data) {
|
||||
// For a table on n variables, we look for function x = f(n - 1) variable.
|
||||
const int num_bits = arity - 1;
|
||||
|
||||
// TODO(user): This is non-deterministic order. We need to fix that or
|
||||
// initially sort the queue of gates to process.
|
||||
int num_functions = 0;
|
||||
for (const auto [key, truth_table] : data) {
|
||||
for (int i = 0; i < arity; ++i) {
|
||||
if (!IsFunction<arity>(i, truth_table)) continue;
|
||||
++num_functions;
|
||||
|
||||
gates_target_.push_back(Literal(key[i], true));
|
||||
gates_inputs_.Add({});
|
||||
for (int j = 0; j < arity; ++j) {
|
||||
if (i != j) {
|
||||
gates_inputs_.AppendToLastVector(Literal(key[j], true));
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the function truth table as a type.
|
||||
// We will canonicalize it further in the main loop.
|
||||
unsigned int type = 0;
|
||||
for (int p = 0; p < (1 << num_bits); ++p) {
|
||||
// Expand from (arity - 1) bits to (arity) bits.
|
||||
const int bigger_p = AddHoleAtPosition(i, p);
|
||||
|
||||
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
|
||||
// target is 1 at this position.
|
||||
type |= 1 << p;
|
||||
DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function.
|
||||
} else {
|
||||
// Note that if there is no feasible assignment for a given p, first
|
||||
// we could have learned a smaller clause, but also we don't really
|
||||
// care what is the value of the target at that point p, so we use
|
||||
// zero.
|
||||
}
|
||||
}
|
||||
|
||||
gates_type_.push_back(type);
|
||||
gates_clause_.push_back(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
timer.AddCounter(absl::StrCat("table", arity), data.size());
|
||||
timer.AddCounter(absl::StrCat("fn", num_bits), num_functions);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -2290,17 +2380,27 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
|
||||
gates_target_.clear();
|
||||
gates_inputs_.clear();
|
||||
gates_type_.clear();
|
||||
gates_clause_.clear();
|
||||
|
||||
// TODO(user): Extract more gates, and encode store their type.
|
||||
ExtractAndGates(timer);
|
||||
timer.AddCounter("and_gates", gates_inputs_.size());
|
||||
ExtractAndGatesAndFillShortTruthTables(timer);
|
||||
|
||||
// TODO(user): We currently do not support this with lrat. Fix.
|
||||
if (lrat_proof_handler_ == nullptr) {
|
||||
ExtractShortGates<3>(timer, truth_table3_);
|
||||
ExtractShortGates<4>(timer, truth_table4_);
|
||||
}
|
||||
|
||||
// All vector have the same size.
|
||||
// Except gates_clause_ which is only filled if we need proof.
|
||||
CHECK_EQ(gates_target_.size(), gates_type_.size());
|
||||
CHECK_EQ(gates_target_.size(), gates_inputs_.size());
|
||||
|
||||
// If two gates have the same type and the same inputs, their targets are
|
||||
// equivalent. We use an hash set to detect that the inputs are the same.
|
||||
absl::flat_hash_set<int, GateHash, GateEq> gate_set(
|
||||
/*capacity=*/gates_inputs_.size(), GateHash(&gates_inputs_),
|
||||
GateEq(&gates_inputs_));
|
||||
/*capacity=*/gates_inputs_.size(), GateHash(&gates_type_, &gates_inputs_),
|
||||
GateEq(&gates_type_, &gates_inputs_));
|
||||
|
||||
// Used to find representatives as we detect equivalent literal.
|
||||
DenseConnectedComponentsFinder union_find;
|
||||
@@ -2330,6 +2430,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
int num_units = 0;
|
||||
int num_processed = 0;
|
||||
int num_equivalences = 0;
|
||||
int arity1_equivalences = 0;
|
||||
while (!queue.empty()) {
|
||||
++num_processed;
|
||||
const int id = queue.back();
|
||||
@@ -2368,13 +2469,13 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
if (is_equivalent) {
|
||||
CHECK_NE(id, other_id);
|
||||
CHECK_GE(other_id, 0);
|
||||
CHECK_EQ(gates_type_[id], gates_type_[other_id]);
|
||||
CHECK_EQ(absl::Span<const LiteralIndex>(gates_inputs_[id]),
|
||||
absl::Span<const LiteralIndex>(gates_inputs_[other_id]));
|
||||
|
||||
// We detected a <=> b (or, equivalently, rep(a) <=> rep(b)).
|
||||
const LiteralIndex a = gates_target_[id];
|
||||
const LiteralIndex b = gates_target_[other_id];
|
||||
|
||||
input_literals_to_gate.RemoveFromFutureOutput(id);
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
lrat_helper.ShortenEquivalencesWithRepresentative(Literal(a));
|
||||
@@ -2382,6 +2483,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
}
|
||||
const LiteralIndex rep_a(union_find.FindRoot(a.value()));
|
||||
const LiteralIndex rep_b(union_find.FindRoot(b.value()));
|
||||
|
||||
if (rep_a != rep_b) {
|
||||
++num_equivalences;
|
||||
const Literal rep_lit_a(rep_a);
|
||||
@@ -2401,45 +2503,58 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
return false;
|
||||
}
|
||||
|
||||
union_find.AddEdge(rep_a.value(), rep_b.value());
|
||||
const LiteralIndex rep(union_find.FindRoot(rep_b.value()));
|
||||
const LiteralIndex to_merge = rep == rep_a ? rep_b : rep_a;
|
||||
input_literals_to_gate.MergeInto(to_merge, rep);
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
if (rep == rep_a) {
|
||||
lrat_helper.AddGateEquivalenceClauses(
|
||||
rep_lit_b, rep_b_implies_rep_a, rep_a_implies_rep_b);
|
||||
} else {
|
||||
lrat_helper.AddGateEquivalenceClauses(
|
||||
rep_lit_a, rep_a_implies_rep_b, rep_b_implies_rep_a);
|
||||
}
|
||||
}
|
||||
for (const bool negate : {false, true}) {
|
||||
const LiteralIndex x =
|
||||
negate ? Literal(rep_a).NegatedIndex() : rep_a;
|
||||
const LiteralIndex y =
|
||||
negate ? Literal(rep_b).NegatedIndex() : rep_b;
|
||||
|
||||
// Re-add to the queue all gates with touched inputs.
|
||||
//
|
||||
// TODO(user): I think we could only add the gates of "to_merge"
|
||||
// before we merge. This part of the code is quite quick in any case.
|
||||
for (const int gate_id : input_literals_to_gate[rep]) {
|
||||
if (in_queue[gate_id]) continue;
|
||||
queue.push_back(gate_id);
|
||||
in_queue[gate_id] = true;
|
||||
union_find.AddEdge(x.value(), y.value());
|
||||
const LiteralIndex rep(union_find.FindRoot(y.value()));
|
||||
const LiteralIndex to_merge = rep == x ? y : x;
|
||||
input_literals_to_gate.MergeInto(to_merge, rep);
|
||||
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
if (rep == x) {
|
||||
lrat_helper.AddGateEquivalenceClauses(
|
||||
Literal(y),
|
||||
negate ? rep_a_implies_rep_b : rep_b_implies_rep_a,
|
||||
negate ? rep_b_implies_rep_a : rep_a_implies_rep_b);
|
||||
} else {
|
||||
lrat_helper.AddGateEquivalenceClauses(
|
||||
Literal(x),
|
||||
negate ? rep_b_implies_rep_a : rep_a_implies_rep_b,
|
||||
negate ? rep_a_implies_rep_b : rep_b_implies_rep_a);
|
||||
}
|
||||
}
|
||||
|
||||
// Re-add to the queue all gates with touched inputs.
|
||||
//
|
||||
// TODO(user): I think we could only add the gates of "to_merge"
|
||||
// before we merge. This part of the code is quite quick in any
|
||||
// case.
|
||||
for (const int gate_id : input_literals_to_gate[rep]) {
|
||||
if (in_queue[gate_id]) continue;
|
||||
queue.push_back(gate_id);
|
||||
in_queue[gate_id] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Canonicalize.
|
||||
// Canonicalize on pass zero, otherwise loop.
|
||||
// Note that even if nothing change, we want to run canonicalize at
|
||||
// least once on the small "truth table" gates.
|
||||
//
|
||||
// Note that sorting works for and_gate and any gate that does not depend
|
||||
// on the order of its inputs. But if we add more fancy functions, we will
|
||||
// need to be careful.
|
||||
//
|
||||
// TODO(user): Because we fix literal, we should also deal with fixed
|
||||
// literals here. Right now we defer this to a later step, but it might
|
||||
// reduce the cascading effect of finding more equivalences.
|
||||
if (pass == 0) {
|
||||
marked_.ResetAllToFalse();
|
||||
if (pass > 0) continue;
|
||||
|
||||
if (gates_type_[id] == kAndGateType) {
|
||||
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
|
||||
marked_.ResetAllToFalse();
|
||||
int new_size = 0;
|
||||
bool is_unit = false;
|
||||
for (const LiteralIndex l : inputs) {
|
||||
@@ -2454,13 +2569,16 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
if (marked_[Literal(rep).Negated()]) {
|
||||
is_unit = true;
|
||||
const Literal to_fix = Literal(gates_target_[id]).Negated();
|
||||
absl::InlinedVector<ClauseId, 4> clause_ids;
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
lrat_helper.AppendFixAndGateTargetClauses(id, Literal(rep),
|
||||
clause_ids);
|
||||
}
|
||||
if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) {
|
||||
return false;
|
||||
if (!assignment_.LiteralIsTrue(to_fix)) {
|
||||
absl::InlinedVector<ClauseId, 4> clause_ids;
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
lrat_helper.AppendFixAndGateTargetClauses(id, Literal(rep),
|
||||
clause_ids);
|
||||
}
|
||||
if (!clause_manager_->InprocessingFixLiteral(to_fix,
|
||||
clause_ids)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -2480,6 +2598,50 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
CHECK_GT(new_size, 0);
|
||||
std::sort(inputs.begin(), inputs.begin() + new_size);
|
||||
gates_inputs_.Shrink(id, new_size);
|
||||
|
||||
// Lets convert to the short "type" if we can. The truth table is simply
|
||||
// a 1 on the last position (where all inputs are ones). We fall back to
|
||||
// the case below to canonicalize further.
|
||||
if (new_size > 4 || lrat_proof_handler_ != nullptr) continue;
|
||||
gates_type_[id] = 1 << ((1 << new_size) - 1);
|
||||
}
|
||||
|
||||
// Generic "short" gates.
|
||||
// We just take the representative and re-canonicalize.
|
||||
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
|
||||
CHECK_GE(gates_type_[id], 0);
|
||||
CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
|
||||
for (LiteralIndex& lit_ref : inputs) {
|
||||
lit_ref = LiteralIndex(union_find.FindRoot(lit_ref.value()));
|
||||
}
|
||||
|
||||
const int new_size = CanonicalizeFunctionTruthTable(
|
||||
gates_target_[id], inputs, gates_type_[id]);
|
||||
if (new_size < inputs.size()) {
|
||||
gates_inputs_.Shrink(id, new_size);
|
||||
}
|
||||
|
||||
if (new_size == 1) {
|
||||
// We have a function of size 1! this is an equivalence.
|
||||
//
|
||||
// TODO(user): deal with it.
|
||||
++arity1_equivalences;
|
||||
input_literals_to_gate.RemoveFromFutureOutput(id);
|
||||
break;
|
||||
} else if (new_size == 0) {
|
||||
// We have a fixed function ! just fix the literal.
|
||||
CHECK(Literal(gates_target_[id]).IsPositive());
|
||||
const Literal to_fix{Literal(gates_target_[id]).Variable(),
|
||||
(gates_type_[id] & 1) == 1};
|
||||
if (!assignment_.LiteralIsTrue(to_fix)) {
|
||||
absl::InlinedVector<ClauseId, 4> clause_ids;
|
||||
if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) {
|
||||
return false;
|
||||
}
|
||||
++num_units;
|
||||
}
|
||||
input_literals_to_gate.RemoveFromFutureOutput(id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2490,6 +2652,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
total_equivalences_ += num_equivalences;
|
||||
total_num_units_ += num_units;
|
||||
|
||||
timer.AddCounter("arity1_equivalences", arity1_equivalences);
|
||||
timer.AddCounter("units", num_units);
|
||||
timer.AddCounter("processed", num_processed);
|
||||
timer.AddCounter("equivalences", num_equivalences);
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/strong_vector.h"
|
||||
#include "ortools/sat/clause.h"
|
||||
#include "ortools/sat/gate_utils.h"
|
||||
#include "ortools/sat/linear_programming_constraint.h"
|
||||
#include "ortools/sat/lrat_proof_handler.h"
|
||||
#include "ortools/sat/model.h"
|
||||
@@ -439,7 +440,8 @@ class BoundedVariableElimination {
|
||||
class GateCongruenceClosure {
|
||||
public:
|
||||
explicit GateCongruenceClosure(Model* model)
|
||||
: sat_solver_(model->GetOrCreate<SatSolver>()),
|
||||
: assignment_(model->GetOrCreate<Trail>()->Assignment()),
|
||||
sat_solver_(model->GetOrCreate<SatSolver>()),
|
||||
implication_graph_(model->GetOrCreate<BinaryImplicationGraph>()),
|
||||
clause_manager_(model->GetOrCreate<ClauseManager>()),
|
||||
clause_id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
|
||||
@@ -454,23 +456,29 @@ class GateCongruenceClosure {
|
||||
|
||||
private:
|
||||
struct GateHash {
|
||||
explicit GateHash(CompactVectorVector<int, LiteralIndex>* g)
|
||||
: gates_inputs(g) {}
|
||||
explicit GateHash(std::vector<int>* t,
|
||||
CompactVectorVector<int, LiteralIndex>* g)
|
||||
: gates_type(t), gates_inputs(g) {}
|
||||
std::size_t operator()(int gate_id) const {
|
||||
return absl::HashOf((*gates_inputs)[gate_id]);
|
||||
return absl::HashOf((*gates_type)[gate_id], (*gates_inputs)[gate_id]);
|
||||
}
|
||||
CompactVectorVector<int, LiteralIndex>* gates_inputs;
|
||||
const std::vector<int>* gates_type;
|
||||
const CompactVectorVector<int, LiteralIndex>* gates_inputs;
|
||||
};
|
||||
|
||||
struct GateEq {
|
||||
explicit GateEq(CompactVectorVector<int, LiteralIndex>* g)
|
||||
: gates_inputs(g) {}
|
||||
explicit GateEq(std::vector<int>* t,
|
||||
CompactVectorVector<int, LiteralIndex>* g)
|
||||
: gates_type(t), gates_inputs(g) {}
|
||||
bool operator()(int gate_a, int gate_b) const {
|
||||
if (gate_a == gate_b) return true;
|
||||
|
||||
// We use absl::span<> comparison.
|
||||
return (*gates_inputs)[gate_a] == (*gates_inputs)[gate_b];
|
||||
return ((*gates_type)[gate_a] == (*gates_type)[gate_b]) &&
|
||||
((*gates_inputs)[gate_a] == (*gates_inputs)[gate_b]);
|
||||
}
|
||||
CompactVectorVector<int, LiteralIndex>* gates_inputs;
|
||||
const std::vector<int>* gates_type;
|
||||
const CompactVectorVector<int, LiteralIndex>* gates_inputs;
|
||||
};
|
||||
|
||||
// Recovers "target_literal = and(literals)" from the model.
|
||||
@@ -483,8 +491,20 @@ class GateCongruenceClosure {
|
||||
// - for all i, target_literal => literal_i (direct binary implication)
|
||||
// - all literal at true => target_literal, this is a clause:
|
||||
// (not(literal[i]) for all i, target_literal).
|
||||
void ExtractAndGates(PresolveTimer& timer);
|
||||
void ExtractAndGatesAndFillShortTruthTables(PresolveTimer& timer);
|
||||
|
||||
// From possible assignment of "arity" given variables, extract functions.
|
||||
template <int arity>
|
||||
void ExtractShortGates(
|
||||
PresolveTimer& timer,
|
||||
const absl::flat_hash_map<std::array<BooleanVariable, arity>,
|
||||
SmallBitset>& data);
|
||||
template <int arity>
|
||||
void AddToTruthTable(absl::Span<const Literal> clause,
|
||||
absl::flat_hash_map<std::array<BooleanVariable, arity>,
|
||||
SmallBitset>& data);
|
||||
|
||||
const VariablesAssignment& assignment_;
|
||||
SatSolver* sat_solver_;
|
||||
BinaryImplicationGraph* implication_graph_;
|
||||
ClauseManager* clause_manager_;
|
||||
@@ -501,17 +521,33 @@ class GateCongruenceClosure {
|
||||
// A Boolean gates correspond to target = f(inputs).
|
||||
//
|
||||
// Note that the inputs are canonicalized. For and_gates, they are sorted,
|
||||
// since the gate function does not depend on the order.
|
||||
// since the gate function does not depend on the order. The type of an
|
||||
// and_gates is kAndGateType.
|
||||
//
|
||||
// TODO(user): for now we have a single gate type. We can easily support more
|
||||
// by creating an extra std::vector<GateType> and adding that to our
|
||||
// GateHash/GateEq hash_set.
|
||||
// Otherwise, we support generic 2 and 3 inputs gates where the type is the
|
||||
// truth table. i.e. target = type[sum value_of_inputs[i] * 2^i]. For such
|
||||
// gate, the target and inputs will always be canonicalized to positive and
|
||||
// sorted literal. We just update the truth table accordingly.
|
||||
static constexpr int kAndGateType = -1;
|
||||
std::vector<LiteralIndex> gates_target_;
|
||||
std::vector<int> gates_type_;
|
||||
CompactVectorVector<int, LiteralIndex> gates_inputs_;
|
||||
|
||||
// For each gate, "the" corresponding clause. For a gate a = and(x,y,...) this
|
||||
// is the clause "x and y and ... => a". Only used for LRAT.
|
||||
std::vector<const SatClause*> gates_clause_;
|
||||
|
||||
// Map (Xi) (sorted) to a bitmask corresponding to the allowed values.
|
||||
// We loop over all short clauses to fill this.
|
||||
//
|
||||
// TODO(user): Shorter clauses impact larger truth table too and we can
|
||||
// combine two size 3 to construct a size 4 (needed for ITE-gate).
|
||||
// not ideal.
|
||||
absl::flat_hash_map<std::array<BooleanVariable, 3>, SmallBitset>
|
||||
truth_table3_;
|
||||
absl::flat_hash_map<std::array<BooleanVariable, 4>, SmallBitset>
|
||||
truth_table4_;
|
||||
|
||||
// For stats.
|
||||
double total_dtime_ = 0.0;
|
||||
double total_wtime_ = 0.0;
|
||||
|
||||
@@ -127,10 +127,11 @@ message SatParameters {
|
||||
enum BinaryMinizationAlgorithm {
|
||||
reserved 2, 3, 4;
|
||||
NO_BINARY_MINIMIZATION = 0;
|
||||
BINARY_MINIMIZATION_FIRST = 1;
|
||||
BINARY_MINIMIZATION_FROM_UIP = 1;
|
||||
BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS = 5;
|
||||
}
|
||||
optional BinaryMinizationAlgorithm binary_minimization_algorithm = 34
|
||||
[default = BINARY_MINIMIZATION_FIRST];
|
||||
[default = BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS];
|
||||
|
||||
// At a really low cost, during the 1-UIP conflict computation, it is easy to
|
||||
// detect if some of the involved reasons are subsumed by the current
|
||||
|
||||
@@ -974,10 +974,16 @@ void SatSolver::ProcessCurrentConflict(
|
||||
switch (parameters_->binary_minimization_algorithm()) {
|
||||
case SatParameters::NO_BINARY_MINIMIZATION:
|
||||
break;
|
||||
case SatParameters::BINARY_MINIMIZATION_FIRST:
|
||||
case SatParameters::BINARY_MINIMIZATION_FROM_UIP:
|
||||
binary_implication_graph_->MinimizeConflictFirst(
|
||||
*trail_, &learned_conflict_, &is_marked_,
|
||||
clause_ids_for_minimization);
|
||||
clause_ids_for_minimization, /*also_use_decisions=*/false);
|
||||
break;
|
||||
case SatParameters::BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS:
|
||||
binary_implication_graph_->MinimizeConflictFirst(
|
||||
*trail_, &learned_conflict_, &is_marked_,
|
||||
clause_ids_for_minimization, /*also_use_decisions=*/true);
|
||||
break;
|
||||
}
|
||||
DCHECK(IsConflictValid(learned_conflict_));
|
||||
}
|
||||
@@ -1060,7 +1066,7 @@ void SatSolver::ProcessCurrentConflict(
|
||||
for (const Literal l : clause) {
|
||||
if (Assignment().LiteralIsFalse(l)) ++num_false;
|
||||
}
|
||||
if (num_false == clause.size()) {
|
||||
if (num_false == clause.size() || clause.size() == 1) {
|
||||
int max_level = 0;
|
||||
for (const Literal l : clause) {
|
||||
const int level = AssignmentLevel(l.Variable());
|
||||
@@ -1366,6 +1372,8 @@ void SatSolver::AppendLratProofForFailingClause(
|
||||
const SatClause* failing_sat_clause = trail_->FailingSatClause();
|
||||
if (failing_sat_clause != nullptr) {
|
||||
failing_clause_id = clauses_propagator_->GetClauseId(failing_sat_clause);
|
||||
} else if (trail_->FailingClauseId() != kNoClauseId) {
|
||||
failing_clause_id = trail_->FailingClauseId();
|
||||
} else {
|
||||
absl::Span<const Literal> failing_clause = trail_->FailingClause();
|
||||
if (failing_clause.size() == 2) {
|
||||
@@ -2342,6 +2350,7 @@ std::string SatSolver::RunningStatisticsString() const {
|
||||
void SatSolver::ProcessNewlyFixedVariables() {
|
||||
SCOPED_TIME_STAT(&stats_);
|
||||
DCHECK_EQ(CurrentDecisionLevel(), 0);
|
||||
if (num_processed_fixed_variables_ == trail_->Index()) return;
|
||||
int num_detached_clauses = 0;
|
||||
int num_binary = 0;
|
||||
|
||||
|
||||
@@ -91,17 +91,19 @@ void SharedStatTables::AddSearchStat(absl::string_view name, Model* model) {
|
||||
void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) {
|
||||
absl::MutexLock mutex_lock(mutex_);
|
||||
auto* sat_solver = model->GetOrCreate<SatSolver>();
|
||||
auto* binary = model->GetOrCreate<BinaryImplicationGraph>();
|
||||
SatSolver::Counters counters = sat_solver->counters();
|
||||
|
||||
if (clauses_table_.empty()) {
|
||||
clauses_table_.push_back({"SAT stats", "ClassicMinim", "LitRemoved",
|
||||
"LitLearned", "LitForgotten", "Subsumed",
|
||||
"MClauses", "MDecisions", "MLitTrue", "MSubsumed",
|
||||
"MLitRemoved", "MReused"});
|
||||
"LitRemovedBinary", "LitLearned", "LitForgotten",
|
||||
"Subsumed", "MClauses", "MDecisions", "MLitTrue",
|
||||
"MSubsumed", "MLitRemoved", "MReused"});
|
||||
}
|
||||
clauses_table_.push_back(
|
||||
{FormatName(name), FormatCounter(counters.num_minimizations),
|
||||
FormatCounter(counters.num_literals_removed),
|
||||
FormatCounter(binary->num_literals_removed()),
|
||||
FormatCounter(counters.num_literals_learned),
|
||||
FormatCounter(counters.num_literals_forgotten),
|
||||
FormatCounter(counters.num_subsumed_clauses),
|
||||
@@ -117,7 +119,6 @@ void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) {
|
||||
bool_var_table_.push_back(
|
||||
{"Boolean variables", "Fixed", "Equiv", "Total", "Left", "Binary"});
|
||||
}
|
||||
auto* binary = model->GetOrCreate<BinaryImplicationGraph>();
|
||||
const int64_t num_fixed = sat_solver->NumFixedVariables();
|
||||
const int64_t num_equiv = binary->num_redundant_literals() / 2;
|
||||
const int64_t num_bools = sat_solver->NumVariables();
|
||||
|
||||
@@ -607,33 +607,6 @@ void MaxBoundedSubsetSum::AddChoicesInternal(absl::Span<const int64_t> values) {
|
||||
}
|
||||
}
|
||||
|
||||
int64_t MaxBoundedSubsetSum::MaxIfAdded(int64_t candidate) const {
|
||||
if (candidate > bound_ || current_max_ == bound_) return current_max_;
|
||||
|
||||
int64_t current_max = current_max_;
|
||||
// Mode 1: vector of all possible sums (with duplicates).
|
||||
if (!sums_.empty()) {
|
||||
for (const int64_t v : sums_) {
|
||||
if (v + candidate > bound_) continue;
|
||||
if (v + candidate > current_max) {
|
||||
current_max = v + candidate;
|
||||
if (current_max == bound_) return current_max;
|
||||
}
|
||||
}
|
||||
return current_max;
|
||||
}
|
||||
|
||||
// Mode 2: bitset of all possible sums.
|
||||
if (!expanded_sums_.empty()) {
|
||||
const int64_t min_useful = std::max<int64_t>(0, current_max_ - candidate);
|
||||
const int64_t max_useful = bound_ - candidate;
|
||||
for (int64_t v = max_useful; v >= min_useful; --v) {
|
||||
if (expanded_sums_[v]) return v + candidate;
|
||||
}
|
||||
}
|
||||
return current_max_;
|
||||
}
|
||||
|
||||
BasicKnapsackSolver::Result BasicKnapsackSolver::Solve(
|
||||
absl::Span<const Domain> domains, absl::Span<const int64_t> coeffs,
|
||||
absl::Span<const int64_t> costs, const Domain& rhs) {
|
||||
|
||||
@@ -550,9 +550,6 @@ class MaxBoundedSubsetSum {
|
||||
// We look for the maximum sum <= bound.
|
||||
void Reset(int64_t bound);
|
||||
|
||||
// Returns the updated max if value was added to the subset-sum.
|
||||
int64_t MaxIfAdded(int64_t candidate) const;
|
||||
|
||||
// Add a value to the base set for which subset sums will be taken.
|
||||
void Add(int64_t value);
|
||||
|
||||
|
||||
@@ -737,21 +737,6 @@ TEST(MaxBoundedSubsetSumTest, SimpleMultiChoice) {
|
||||
EXPECT_EQ(bounded_subset_sum.CurrentMax(), 31);
|
||||
}
|
||||
|
||||
TEST(MaxBoundedSubsetSumTest, CheckMaxIfAdded) {
|
||||
MaxBoundedSubsetSum bounded_subset_sum(34);
|
||||
bounded_subset_sum.Add(10);
|
||||
bounded_subset_sum.Add(10);
|
||||
bounded_subset_sum.Add(10);
|
||||
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(12), 32);
|
||||
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(15), 30);
|
||||
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(34), 34);
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
bounded_subset_sum.Add(18);
|
||||
}
|
||||
EXPECT_EQ(bounded_subset_sum.CurrentMax(), 30);
|
||||
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(5), 33);
|
||||
}
|
||||
|
||||
static void BM_bounded_subset_sum(benchmark::State& state) {
|
||||
random_engine_t random_;
|
||||
const int num_items = state.range(0);
|
||||
|
||||
@@ -635,18 +635,21 @@ const std::vector<Literal>& SharedTreeWorker::DecisionReason(int level) {
|
||||
}
|
||||
|
||||
bool SharedTreeWorker::AddDecisionImplication(Literal lit, int level) {
|
||||
CHECK_GT(level, 0);
|
||||
CHECK_NE(lit.Index(), kNoLiteralIndex);
|
||||
CHECK(!sat_solver_->Assignment().LiteralIsTrue(lit));
|
||||
absl::Span<const Literal> reason = DecisionReason(level);
|
||||
if (sat_solver_->Assignment().LiteralIsFalse(lit)) {
|
||||
VLOG(2) << "Closing subtree via impl at " << level + 1
|
||||
<< " assigned=" << assigned_tree_.MaxLevel();
|
||||
integer_trail_->ReportConflict(DecisionReason(level), {});
|
||||
trail_->MutableConflict()->assign(reason.begin(), reason.end());
|
||||
manager_->CloseTree(assigned_tree_, level);
|
||||
assigned_tree_literals_.clear();
|
||||
return false;
|
||||
}
|
||||
VLOG(2) << "Learned shared clause";
|
||||
return integer_trail_->EnqueueLiteral(lit, DecisionReason(level), {});
|
||||
trail_->GetEmptyVectorToStoreReason()->assign(reason.begin(), reason.end());
|
||||
return trail_->EnqueueWithStoredReason(kNoClauseId, lit);
|
||||
}
|
||||
|
||||
bool SharedTreeWorker::AddImplications() {
|
||||
|
||||
Reference in New Issue
Block a user