[CP-SAT] morework on lrat

This commit is contained in:
Laurent Perron
2025-12-04 15:51:17 +01:00
committed by Corentin Le Molgat
parent 69a94a445e
commit 3b18bdd58b
23 changed files with 1083 additions and 194 deletions

View File

@@ -298,6 +298,29 @@ cc_library(
],
)
cc_library(
name = "gate_utils",
hdrs = ["gate_utils.h"],
deps = [
":sat_base",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/types:span",
],
)
cc_test(
name = "gate_utils_test",
srcs = ["gate_utils_test.cc"],
deps = [
":gate_utils",
":sat_base",
"//ortools/base:gmock_main",
"//ortools/base:logging",
"@abseil-cpp//absl/random",
"@abseil-cpp//absl/types:span",
],
)
cc_proto_library(
name = "cp_model_cc_proto",
visibility = ["//visibility:public"],
@@ -1585,6 +1608,7 @@ cc_library(
hdrs = ["sat_inprocessing.h"],
deps = [
":clause",
":gate_utils",
":linear_programming_constraint",
":lrat_proof_handler",
":model",
@@ -1939,7 +1963,9 @@ cc_library(
hdrs = ["integer.h"],
visibility = ["//visibility:public"],
deps = [
":clause",
":integer_base",
":lrat_proof_handler",
":model",
":sat_base",
":sat_parameters_cc_proto",
@@ -4279,6 +4305,7 @@ cc_library(
":recordio",
":sat_base",
":synchronization",
":util",
"//ortools/base:file",
"//ortools/base:intops",
"//ortools/base:timer",
@@ -4291,6 +4318,19 @@ cc_library(
],
)
cc_test(
name = "lrat_proof_handler_test",
srcs = ["lrat_proof_handler_test.cc"],
deps = [
":lrat_proof_handler",
":model",
":sat_base",
":util",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/types:span",
],
)
cc_test(
name = "lrat_checker_test",
srcs = ["lrat_checker_test.cc"],

View File

@@ -404,7 +404,7 @@ bool AllDifferentConstraint::Propagate() {
}
}
return trail_->EnqueueWithStoredReason(x_lit.Negated());
return trail_->EnqueueWithStoredReason(kNoClauseId, x_lit.Negated());
}
}
}

View File

@@ -308,7 +308,7 @@ bool CircuitPropagator::Propagate() {
std::vector<Literal>* reason = trail_.GetEmptyVectorToStoreReason();
FillReasonForPath(start_node, reason);
enforcement_helper_.AddEnforcementReason(enforcement_id_, reason);
if (!trail_.EnqueueWithStoredReason(literal.Negated())) {
if (!trail_.EnqueueWithStoredReason(kNoClauseId, literal.Negated())) {
return false;
}
}
@@ -357,7 +357,8 @@ bool CircuitPropagator::Propagate() {
if (extra_reason != kFalseLiteralIndex) {
reason->push_back(Literal(extra_reason));
}
const bool ok = trail_.EnqueueWithStoredReason(literal.Negated());
const bool ok =
trail_.EnqueueWithStoredReason(kNoClauseId, literal.Negated());
if (!ok) return false;
continue;
}
@@ -398,7 +399,7 @@ bool CircuitPropagator::Propagate() {
std::vector<Literal>* reason = trail_.GetEmptyVectorToStoreReason();
FillReasonForPath(start_node, reason);
enforcement_helper_.AddEnforcementReason(enforcement_id_, reason);
const bool ok = trail_.EnqueueWithStoredReason(literal);
const bool ok = trail_.EnqueueWithStoredReason(kNoClauseId, literal);
if (!ok) return false;
} else {
trail_.EnqueueWithSameReasonAs(literal, variable_with_same_reason);
@@ -669,8 +670,8 @@ bool CircuitCoveringPropagator::Propagate() {
!trail_->Assignment().LiteralIsFalse(graph_[end][start])) {
auto* reason = trail_->GetEmptyVectorToStoreReason();
FillFixedPathInReason(start, end, reason);
const bool ok =
trail_->EnqueueWithStoredReason(graph_[end][start].Negated());
const bool ok = trail_->EnqueueWithStoredReason(
kNoClauseId, graph_[end][start].Negated());
if (!ok) return false;
}
}

View File

@@ -477,20 +477,7 @@ bool ClauseManager::InprocessingAddUnitClause(ClauseId unit_clause_id,
bool ClauseManager::InprocessingFixLiteral(
Literal true_literal, absl::Span<const ClauseId> clause_ids) {
DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
if (trail_->Assignment().LiteralIsTrue(true_literal)) return true;
ClauseId clause_id = kNoClauseId;
if (lrat_proof_handler_ != nullptr) {
clause_id = clause_id_generator_->GetNextId();
lrat_proof_handler_->AddInferredClause(clause_id, {true_literal},
clause_ids);
}
trail_->EnqueueWithUnitReason(clause_id, true_literal);
// Even when all clauses are detached, we can propagate the implication
// graph and we do that right away.
return implication_graph_->Propagate(trail_);
return implication_graph_->FixLiteral(true_literal, clause_ids);
}
void ClauseManager::ChangeLbdIfBetter(SatClause* clause, int new_lbd) {
@@ -943,21 +930,17 @@ bool BinaryImplicationGraph::AddBinaryClauseInternal(
ClauseId id, Literal a, Literal b, bool change_reason,
bool delete_non_representative_id) {
SCOPED_TIME_STAT(&stats_);
DCHECK_GE(a.Index(), 0);
DCHECK_GE(b.Index(), 0);
// Tricky: If this is the first clause, the propagator will be added and
// assumed to be in a "propagated" state. This makes sure this is the case.
if (no_constraint_ever_added_) propagation_trail_index_ = trail_->Index();
no_constraint_ever_added_ = false;
Literal rep_a = a;
Literal rep_b = b;
ClauseId rep_id = kNoClauseId;
if (is_redundant_[a.Index()]) {
rep_a = Literal(representative_of_[a.Index()]);
}
if (is_redundant_[b.Index()]) {
rep_b = Literal(representative_of_[b.Index()]);
}
const Literal rep_a = RepresentativeOf(a);
const Literal rep_b = RepresentativeOf(b);
if (rep_a == rep_b.Negated()) return true;
if (lrat_proof_handler_ != nullptr) {
@@ -1361,15 +1344,50 @@ void BinaryImplicationGraph::Reimply(Trail* trail, int old_trail_index) {
// can take advantage of that.
void BinaryImplicationGraph::MinimizeConflictFirst(
const Trail& trail, std::vector<Literal>* conflict,
SparseBitset<BooleanVariable>* marked, std::vector<ClauseId>* clause_ids) {
SparseBitset<BooleanVariable>* marked, std::vector<ClauseId>* clause_ids,
bool also_use_decisions) {
SCOPED_TIME_STAT(&stats_);
DCHECK(!conflict->empty());
is_marked_.ClearAndResize(LiteralIndex(implications_and_amos_.size()));
tmp_to_keep_.clear();
tmp_to_keep_.push_back(conflict->front().Negated());
if (lrat_proof_handler_ != nullptr) {
MarkDescendants</*fill_implied_by=*/true>(conflict->front().Negated());
} else {
MarkDescendants(conflict->front().Negated());
}
// Because the decision cannot be removed from the conflict, we know they will
// stay, so it is okay to see what they propagate in the binary implication
// graph. Technically we could do that for any first literal of a decision
// level. Improve?
std::vector<std::pair<Literal, int>> decisions;
if (also_use_decisions) {
for (int i = 1; i < conflict->size(); ++i) {
const auto& info = trail_->Info((*conflict)[i].Variable());
if (info.type == AssignmentType::kSearchDecision) {
decisions.push_back({(*conflict)[i].Negated(), info.level});
}
}
absl::c_stable_sort(decisions, [](const std::pair<LiteralIndex, int>& a,
const std::pair<LiteralIndex, int>& b) {
return a.second > b.second;
});
}
// Keep marking everything propagated by the decisions, and make sure we
// don't remove the one from which we called MarkDescendants().
for (const auto [literal, unused_level] : decisions) {
if (is_marked_[literal]) continue;
tmp_to_keep_.push_back(literal);
if (lrat_proof_handler_ != nullptr) {
MarkDescendants</*fill_implied_by=*/true>(literal);
} else {
MarkDescendants(literal);
}
}
for (const LiteralIndex i : is_marked_.PositionsSetAtLeastOnce()) {
// TODO(user): if this is false, then we actually have a conflict of size 2.
// This can only happen if the binary clause was not propagated properly
@@ -1379,6 +1397,9 @@ void BinaryImplicationGraph::MinimizeConflictFirst(
marked->Set(Literal(i).Variable());
}
}
// Remove all marked literals from the conflict.
for (const Literal l : tmp_to_keep_) is_marked_.Clear(l);
if (lrat_proof_handler_ != nullptr) {
RemoveRedundantLiterals</*fill_clause_ids=*/true>(conflict, clause_ids);
} else {

View File

@@ -46,15 +46,6 @@
namespace operations_research {
namespace sat {
// A generator for ClauseIds. Not thread-safe.
class ClauseIdGenerator {
public:
ClauseId GetNextId() { return ClauseId(next_id_++); }
private:
ClauseId next_id_ = ClauseId(1);
};
// This is how the SatSolver stores a clause. A clause is just a disjunction of
// literals. In many places, we just use vector<literal> to encode one. But in
// the critical propagation code, we use this class to remove one memory
@@ -709,7 +700,8 @@ class BinaryImplicationGraph : public SatPropagator {
// details about the different algorithms.
void MinimizeConflictFirst(const Trail& trail, std::vector<Literal>* c,
SparseBitset<BooleanVariable>* marked,
std::vector<ClauseId>* clause_ids);
std::vector<ClauseId>* clause_ids,
bool also_use_decisions);
// Appends the IDs of the unit and binary clauses that imply the given
// literals to `clause_ids`. Either `MinimizeConflictFirst` or
@@ -985,6 +977,12 @@ class BinaryImplicationGraph : public SatPropagator {
return implications_and_amos_[lit].offsets();
}
// Simple wrapper to not forget to output newly fixed variable to the DRAT
// or LRAT proof (with clause_ids as proof) if needed. This will propagate
// right away the implications.
bool FixLiteral(Literal true_literal,
absl::Span<const ClauseId> clause_ids = {});
private:
friend class LratEquivalenceHelper;
@@ -1005,12 +1003,6 @@ class BinaryImplicationGraph : public SatPropagator {
clause_id_[{a, b}] = id;
}
// Simple wrapper to not forget to output newly fixed variable to the DRAT
// or LRAT proof (with clause_ids as proof) if needed. This will propagate
// right away the implications.
bool FixLiteral(Literal true_literal,
absl::Span<const ClauseId> clause_ids = {});
// Removes any literal whose negation is marked (except the first one). If
// `fill_clause_ids` is true, fills the LRAT proof for this change in
// `clause_ids` (this requires an LRAT proof handler to be set, and
@@ -1117,12 +1109,14 @@ class BinaryImplicationGraph : public SatPropagator {
int64_t num_redundant_literals_ = 0;
// Bitset used by MinimizeClause().
//
// TODO(user): use the same one as the one used in the classic minimization
// because they are already initialized. Moreover they contains more
// information.
// information?
SparseBitset<LiteralIndex> is_marked_;
SparseBitset<LiteralIndex> tmp_bitset_;
SparseBitset<LiteralIndex> is_simplified_;
std::vector<Literal> tmp_to_keep_;
// Used by AppendImplicationChains() to avoid processing a unit clause several
// times.

View File

@@ -858,9 +858,6 @@ std::vector<SatParameters> GetFullWorkerParameters(
ModelHasSchedulingConstraints(cp_model);
// Our current set of strategies
//
// TODO(user): Avoid launching two strategies if they are the same,
// like if there is no lp, or everything is already linearized at level 1.
std::vector<std::string> names;
// Starts by adding user specified ones.
@@ -1021,7 +1018,15 @@ std::vector<SatParameters> GetFullWorkerParameters(
if (result.size() > num_to_keep) {
result.resize(std::max(0, num_to_keep));
} else if (!result.empty() && num_to_keep >= 0) {
// If we have less parameters, duplicate the first one until we have enough.
// This is a bit hacky but easily allow to do experiment with n times the
// same subsolver.
while (result.size() < num_to_keep) {
result.push_back(result[0]);
}
}
return result;
}

199
ortools/sat/gate_utils.h Normal file
View File

@@ -0,0 +1,199 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Functions to manipulate a "small" truth table where
// f(X0, X1, X2) is true iff bitmask[X0 + (X1 << 1) + (X2 << 2)] is true.
#ifndef ORTOOLS_SAT_GATE_UTILS_H_
#define ORTOOLS_SAT_GATE_UTILS_H_
#include <cstdint>
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "ortools/sat/sat_base.h"
namespace operations_research::sat {
using SmallBitset = uint32_t;
// Sort the key and modify the truth table accordingly.
//
// Note that we don't deal with identical key here, but the function
// CanonicalizeFunctionTruthTable() does, and that is sufficient for our use
// case.
template <typename VarOrLiteral>
void CanonicalizeTruthTable(absl::Span<VarOrLiteral> key,
SmallBitset& bitmask) {
const int num_bits = key.size();
CHECK_EQ(bitmask >> (1 << num_bits), 0);
for (int i = 0; i < num_bits; ++i) {
for (int j = i + 1; j < num_bits; ++j) {
if (key[i] <= key[j]) continue;
std::swap(key[i], key[j]);
// We need to swap bit positions i and j in bitmask.
SmallBitset new_bitmask = 0;
for (int p = 0; p < (1 << num_bits); ++p) {
const int value_i = (p >> i) & 1;
const int value_j = (p >> j) & 1;
int new_p = p;
new_p ^= (value_i << i) ^ (value_j << j); // Clear.
new_p ^= (value_i << j) ^ (value_j << i); // Swap.
new_bitmask |= ((bitmask >> p) & 1) << new_p;
}
bitmask = new_bitmask;
CHECK_EQ(bitmask >> (1 << num_bits), 0)
<< i << " " << j << " " << num_bits;
}
}
CHECK(std::is_sorted(key.begin(), key.end()));
}
// Given a clause, return the truth table corresponding to it.
// Namely, a single value should be excluded.
inline void FillKeyAndBitmask(absl::Span<const Literal> clause,
absl::Span<BooleanVariable> key,
SmallBitset& bitmask) {
CHECK_EQ(clause.size(), key.size());
const int num_bits = clause.size();
bitmask = ~SmallBitset(0) >> (32 - (1 << num_bits)); // All allowed;
CHECK_EQ(bitmask >> (1 << num_bits), 0) << num_bits;
SmallBitset bit_to_remove = 0;
for (int i = 0; i < num_bits; ++i) {
key[i] = clause[i].Variable();
bit_to_remove |= (clause[i].IsPositive() ? 0 : 1) << i;
}
bitmask ^= (SmallBitset(1) << bit_to_remove);
CHECK_EQ(bitmask >> (1 << num_bits), 0) << bit_to_remove << " " << num_bits;
CanonicalizeTruthTable<BooleanVariable>(key, bitmask);
}
// Returns true iff the truth table encoded in bitmask encode a function
// Xi = f(Xj, j != i);
template <int num_bits>
bool IsFunction(int i, SmallBitset truth_table) {
DCHECK_GE(i, 0);
DCHECK_LT(i, num_bits);
// We need to check that there is never two possibilities for Xi.
for (int p = 0; p < (1 << num_bits); ++p) {
if ((truth_table >> p) & (truth_table >> (p ^ (1 << i))) & 1) return false;
}
return true;
}
inline int AddHoleAtPosition(int i, int bitset) {
return (bitset & ((1 << i) - 1)) + ((bitset >> i) << (i + 1));
}
// The function is target = function_values[inputs as bit position].
//
// TODO(user): This can be optimized with more bit twiddling if needed.
inline int CanonicalizeFunctionTruthTable(LiteralIndex& target,
absl::Span<LiteralIndex> inputs,
int& int_function_values) {
// We want to fit on an int.
CHECK_LE(inputs.size(), 4);
// We assume smaller type.
SmallBitset function_values = int_function_values;
const int num_bits = inputs.size();
CHECK_LE(num_bits, 4); // Truth table must fit on an int.
const SmallBitset all_one = (1 << (1 << num_bits)) - 1;
CHECK_EQ(function_values & ~all_one, 0);
// Make sure target is positive.
if (!Literal(target).IsPositive()) {
target = Literal(target).Negated();
function_values = function_values ^ all_one;
CHECK_EQ(function_values >> (1 << num_bits), 0);
}
// Make sure all inputs are positive.
for (int i = 0; i < num_bits; ++i) {
if (Literal(inputs[i]).IsPositive()) continue;
inputs[i] = Literal(inputs[i]).NegatedIndex();
// Position p go to position (p ^ (1 << i)).
SmallBitset new_truth_table = 0;
const SmallBitset to_xor = 1 << i;
for (int p = 0; p < (1 << num_bits); ++p) {
new_truth_table |= ((function_values >> p) & 1) << (p ^ to_xor);
}
function_values = new_truth_table;
CHECK_EQ(function_values >> (1 << num_bits), 0);
}
// Sort the inputs now.
CanonicalizeTruthTable<LiteralIndex>(inputs, function_values);
CHECK_EQ(function_values >> (1 << num_bits), 0);
// Merge identical variables.
for (int i = 0; i < inputs.size(); ++i) {
for (int j = i + 1; j < inputs.size();) {
if (inputs[i] == inputs[j]) {
// Lets remove input j.
for (int k = j; k + 1 < inputs.size(); ++k) inputs[k] = inputs[k + 1];
inputs.remove_suffix(1);
SmallBitset new_truth_table = 0;
for (int p = 0; p < (1 << inputs.size()); ++p) {
int extended_p = AddHoleAtPosition(j, p);
extended_p |= ((p >> i) & 1) << j; // fill it with bit i.
new_truth_table |= ((function_values >> extended_p) & 1) << p;
}
function_values = new_truth_table;
} else {
++j;
}
}
}
// Lower arity?
// This can happen if the output do not depend on one of the inputs.
for (int i = 0; i < inputs.size();) {
bool remove = true;
for (int p = 0; p < (1 << inputs.size()); ++p) {
if (((function_values >> p) & 1) !=
((function_values >> (p ^ (1 << i))) & 1)) {
remove = false;
break;
}
}
if (remove) {
// Lets remove input i.
for (int k = i; k + 1 < inputs.size(); ++k) inputs[k] = inputs[k + 1];
inputs.remove_suffix(1);
SmallBitset new_truth_table = 0;
for (int p = 0; p < (1 << inputs.size()); ++p) {
const int extended_p = AddHoleAtPosition(i, p);
new_truth_table |= ((function_values >> extended_p) & 1) << p;
}
function_values = new_truth_table;
} else {
++i;
}
}
int_function_values = function_values;
return inputs.size();
}
} // namespace operations_research::sat
#endif // ORTOOLS_SAT_GATE_UTILS_H_

View File

@@ -0,0 +1,147 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/sat/gate_utils.h"
#include <array>
#include <bitset>
#include <cstdint>
#include <vector>
#include "absl/random/random.h"
#include "absl/types/span.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/base/logging.h"
#include "ortools/sat/sat_base.h"
namespace operations_research::sat {
namespace {
TEST(CanonicalizeTruthTableTest, BasicBehavior1) {
std::array<int, 3> key = {0, 2, 1};
// no change here.
SmallBitset bitmask = 0b10101010;
CanonicalizeTruthTable<int>(absl::MakeSpan(key), bitmask);
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b10101010));
}
TEST(CanonicalizeTruthTableTest, BasicBehavior2) {
std::array<int, 3> key = {2, 0, 1};
SmallBitset bitmask = 0b10101010;
CanonicalizeTruthTable<int>(absl::MakeSpan(key), bitmask);
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11110000));
}
TEST(CanonicalizeTruthTableTest, BasicBehavior3) {
std::array<int, 3> key = {1, 0, 2};
SmallBitset bitmask = 0b10101010;
CanonicalizeTruthTable<int>(absl::MakeSpan(key), bitmask);
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11001100));
}
TEST(FillKeyAndBitmaskTest, BasicBehavior1) {
std::array<BooleanVariable, 3> key;
SmallBitset bitmask;
FillKeyAndBitmask({Literal(+1), Literal(-2), Literal(+3)},
absl::MakeSpan(key), bitmask);
EXPECT_THAT(key,
::testing::ElementsAre(BooleanVariable(0), BooleanVariable(1),
BooleanVariable(2)));
// The bit number 2 = 0b010 should be off.
EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11111011));
}
TEST(IsFunctionTest, ConstantValue) {
EXPECT_TRUE(IsFunction<3>(0, 0b10101010));
EXPECT_FALSE(IsFunction<3>(1, 0b10101010));
EXPECT_FALSE(IsFunction<3>(2, 0b10101010));
}
TEST(AddHoleAtPositionTest, BasicTest) {
EXPECT_EQ(AddHoleAtPosition(0, 0xFF), 0b111111110);
EXPECT_EQ(AddHoleAtPosition(1, 0xFF), 0b111111101);
EXPECT_EQ(AddHoleAtPosition(8, 0xFF), 0b011111111);
}
TEST(CanonicalizeFunctionTruthTableTest, RandomTest) {
absl::BitGen random;
const int num_vars = 8;
for (int num_test = 0; num_test < 1000; ++num_test) {
// Lets generate a random function on k random variables.
const int k = absl::Uniform(random, 0, 4);
const int table = absl::Uniform<uint64_t>(random, 0, 1 << (1 << k));
const Literal output(BooleanVariable(100), absl::Bernoulli(random, 0.5));
std::vector<Literal> inputs;
for (int i = 0; i < k; ++i) {
inputs.push_back(
Literal(BooleanVariable(absl::Uniform(random, 0, num_vars)),
absl::Bernoulli(random, 0.5)));
}
LiteralIndex new_output_index = output.Index();
std::vector<Literal> new_inputs = inputs;
std::vector<LiteralIndex> new_inputs_index;
for (const Literal lit : new_inputs) {
new_inputs_index.push_back(lit.Index());
}
int new_table = table;
const int new_size = CanonicalizeFunctionTruthTable(
new_output_index, absl::MakeSpan(new_inputs_index), new_table);
new_inputs_index.resize(new_size);
new_inputs.clear();
for (const LiteralIndex lit_index : new_inputs_index) {
new_inputs.push_back(Literal(lit_index));
}
// Log before potential failure.
LOG(INFO) << "IN arity=" << k << " " << output << " = f(" << inputs << ") "
<< std::bitset<16>(table);
LOG(INFO) << "OUT arity=" << new_size << " " << Literal(new_output_index)
<< " = f(" << new_inputs << ") " << std::bitset<16>(new_table);
// Now check that both function always take the same value.
for (int m = 0; m < (1 << num_vars); ++m) {
int index = 0;
for (int i = 0; i < inputs.size(); ++i) {
const Literal lit = inputs[i];
int value = (m >> lit.Variable().value()) & 1;
if (!lit.IsPositive()) value = 1 - value;
index |= value << i;
}
const int target_value = (table >> index) & 1;
int new_index = 0;
for (int i = 0; i < new_inputs.size(); ++i) {
const Literal lit = new_inputs[i];
int value = (m >> lit.Variable().value()) & 1;
if (!lit.IsPositive()) value = 1 - value;
new_index |= value << i;
}
const int new_target_value = (new_table >> new_index) & 1;
if (output == Literal(new_output_index)) {
ASSERT_EQ(target_value, new_target_value) << index << " " << new_index;
} else {
ASSERT_EQ(output, Literal(new_output_index).Negated());
ASSERT_EQ(target_value, 1 - new_target_value)
<< index << " " << new_index;
}
}
}
}
} // namespace
} // namespace operations_research::sat

View File

@@ -21,6 +21,7 @@
#include <deque>
#include <functional>
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -34,7 +35,9 @@
#include "absl/types/span.h"
#include "ortools/base/logging.h"
#include "ortools/base/strong_vector.h"
#include "ortools/sat/clause.h"
#include "ortools/sat/integer_base.h"
#include "ortools/sat/lrat_proof_handler.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/sat_parameters.pb.h"
@@ -105,6 +108,8 @@ class IntegerEncoder {
explicit IntegerEncoder(Model* model)
: sat_solver_(model->GetOrCreate<SatSolver>()),
trail_(model->GetOrCreate<Trail>()),
clause_id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
lrat_proof_handler_(model->Mutable<LratProofHandler>()),
delayed_to_fix_(model->GetOrCreate<DelayedRootLevelDeduction>()),
domains_(*model->GetOrCreate<IntegerDomains>()),
num_created_variables_(0) {}
@@ -291,9 +296,15 @@ class IntegerEncoder {
Literal(sat_solver_->NewBooleanVariable(), true);
literal_index_true_ = literal_true.Index();
// This might return false if we are already UNSAT.
// TODO(user): Make sure we abort right away on unsat!
(void)sat_solver_->AddUnitClause(literal_true);
ClauseId clause_id = kNoClauseId;
if (lrat_proof_handler_ != nullptr) {
clause_id = clause_id_generator_->GetNextId();
// We cannot prove `literal_true` by unit propagation, but we can with a
// RAT inference (trivial here since there are no clauses containing the
// negation of the pivot `literal_true`).
lrat_proof_handler_->AddInferredClause(clause_id, {literal_true}, {});
}
trail_->EnqueueWithUnitReason(clause_id, literal_true);
}
return Literal(literal_index_true_);
}
@@ -328,6 +339,8 @@ class IntegerEncoder {
SatSolver* sat_solver_;
Trail* trail_;
ClauseIdGenerator* clause_id_generator_;
LratProofHandler* lrat_proof_handler_;
DelayedRootLevelDeduction* delayed_to_fix_;
const IntegerDomains& domains_;
@@ -339,7 +352,7 @@ class IntegerEncoder {
// corresponding to the same variable).
//
// Note that we only keep this for positive variable.
// The one for the negation can be infered by it.
// The one for the negation can be inferred by it.
//
// Like x >= 1 x >= 4 x >= 5
// Correspond to x <= 0 x <= 3 x <= 4

View File

@@ -512,8 +512,8 @@ void IntegerConflictResolution::ComputeFirstUIPConflict(
}
if (uip_found) {
if (params_.binary_minimization_algorithm() ==
SatParameters::BINARY_MINIMIZATION_FIRST) {
if (params_.binary_minimization_algorithm() !=
SatParameters::NO_BINARY_MINIMIZATION) {
if (conflict->empty()) {
// This one will always stay in the conflict, even after
// minimization. So we can use it to minimize the conflict and avoid

View File

@@ -14,6 +14,7 @@
#include "ortools/sat/lrat_proof_handler.h"
#include <algorithm>
#include <bitset>
#include <cstdint>
#include <cstdlib>
#include <fstream>
@@ -23,6 +24,7 @@
#include <string_view>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
@@ -40,6 +42,7 @@
#include "ortools/sat/recordio.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/synchronization.h"
#include "ortools/sat/util.h"
#if defined(_MSC_VER)
ABSL_FLAG(std::string, cp_model_drat_output, ".\\drat.txt",
@@ -423,6 +426,7 @@ std::unique_ptr<LratProofHandler> LratProofHandler::MaybeCreate(Model* model) {
LratProofHandler::LratProofHandler(Model* model)
: id_(model->GetOrCreate<SharedLratProofStatus>()->NewSubSolverId()),
id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
proof_status_(model->GetOrCreate<SharedLratProofStatus>()) {
const SatParameters& params = *model->GetOrCreate<SatParameters>();
if (params.check_lrat_proof()) {
@@ -448,7 +452,7 @@ LratProofHandler::LratProofHandler(Model* model)
bool LratProofHandler::AddProblemClause(ClauseId id,
absl::Span<const Literal> clause) {
VLOG(1) << "AddProblemClause: id=" << id
VLOG(2) << "AddProblemClause: id=" << id
<< " literals=" << absl::StrJoin(clause, ",");
if (all_problem_clauses_loaded_ && debug_crash_on_error_) {
LOG(FATAL) << "LRAT error: problem clauses must not be added after "
@@ -481,7 +485,7 @@ bool LratProofHandler::AddInferredClause(
ClauseId id, absl::Span<const Literal> clause,
absl::Span<const ClauseId> unit_ids,
absl::Span<const LratChecker::RatIds> rat) {
VLOG(1) << "AddInferredClause: id=" << id
VLOG(2) << "AddInferredClause: id=" << id
<< " literals=" << absl::StrJoin(clause, ",")
<< " unit_ids=" << absl::StrJoin(unit_ids, ",") << " rat={"
<< absl::StrJoin(rat, " ") << "}";
@@ -508,7 +512,7 @@ bool LratProofHandler::AddInferredClause(
bool LratProofHandler::AddImportedClause(ClauseId id,
absl::Span<const Literal> clause) {
VLOG(1) << "AddImportedClause: id=" << id
VLOG(2) << "AddImportedClause: id=" << id
<< " literals=" << absl::StrJoin(clause, ",");
if (lrat_checker_ != nullptr &&
!lrat_checker_->AddProblemClause(id, clause)) {
@@ -526,7 +530,7 @@ bool LratProofHandler::AddImportedClause(ClauseId id,
bool LratProofHandler::AddAssumedClause(ClauseId id,
absl::Span<const Literal> clause) {
VLOG(1) << "AddAssumedClause: id=" << id
VLOG(2) << "AddAssumedClause: id=" << id
<< " literals=" << absl::StrJoin(clause, ",");
if (debug_crash_on_error_) {
LOG(FATAL) << "LRAT error: assumed clauses are not supposed to happen";
@@ -571,7 +575,7 @@ void LratProofHandler::DeleteClause(ClauseId id,
delete_pinned_clause_ = true;
return;
}
VLOG(1) << "DeleteClause: id=" << id
VLOG(2) << "DeleteClause: id=" << id
<< " literals=" << absl::StrJoin(clause, ",");
if (lrat_checker_ != nullptr) {
lrat_checker_->DeleteClauses({id});
@@ -638,5 +642,177 @@ void LratProofHandler::Close(bool model_is_unsat) {
}
}
bool LratProofHandler::AddAndProveInferredClauseByEnumeration(
ClauseId new_id, absl::Span<const Literal> new_clause,
absl::Span<const ClauseId> ids_for_proof,
const CompactVectorVector<int, Literal>& clauses_for_proof) {
CHECK_EQ(ids_for_proof.size(), clauses_for_proof.size());
CHECK(!clauses_for_proof.empty());
// First we count the number of variables appearing and have a separate dense
// index for them. The first new_clause.size() dense index are exactly the
// literal of the new_clause.
absl::flat_hash_map<BooleanVariable, int> to_dense_index;
std::vector<Literal> dense_index_to_literal;
dense_index_to_literal.assign(new_clause.begin(), new_clause.end());
for (const Literal lit : new_clause) {
const auto [it, inserted] =
to_dense_index.insert({lit.Variable(), to_dense_index.size()});
if (!inserted) {
VLOG(2) << "Duplicate variable in new_clause! " << new_clause;
return false;
}
}
// Then any new BooleanVariable appearing get the next dense index.
std::vector<Literal> relevant_literals;
for (int i = 0; i < clauses_for_proof.size(); ++i) {
for (const Literal lit : clauses_for_proof[i]) {
const auto [it, inserted] =
to_dense_index.insert({lit.Variable(), to_dense_index.size()});
if (inserted) {
relevant_literals.push_back(lit);
}
}
}
// Too many variables.
//
// TODO(user): The limit can be increased a bit if needed.
if (to_dense_index.size() > 6) {
VLOG(2) << "Too many variables";
return false;
}
// For the proof we will need all clauses of the form
// {new_clause, l0, ..., lk} for all k in [0, n) and
// li = relevant_literals[i] OR relevant_literals[i].Negated().
//
// That give us 2^(n + 1) intermediate clauses.
// Their ids will be stored in (1 << k) + binary_encoding_of_the_li.
const int n = to_dense_index.size() - new_clause.size();
CHECK_EQ(n, relevant_literals.size());
const int num_intermediates = 1 << (n + 1);
std::vector<ClauseId> ids(num_intermediates, kNoClauseId);
if (n == 0) {
VLOG(2) << "Nothing to prove! An existing clause is included inside";
return false;
}
VLOG(2) << "Starting proof n= " << n << " " << num_intermediates;
// Any initial clause can be used to prove all the intermediates that contains
// it. Note that this code supports duplicate literals in the clauses.
for (int i = 0; i < clauses_for_proof.size(); ++i) {
bool skip = false;
int base_index = 0;
int mask = 0;
int k = 0;
for (const Literal lit : clauses_for_proof[i]) {
const int dense_index = to_dense_index[lit.Variable()];
if (dense_index < new_clause.size()) {
// Check that the literal is the same as in the new_clause, if
// not, this clause will not be needed for the proof.
if (lit != new_clause[dense_index]) {
skip = true;
break;
}
} else {
k = std::max(k, dense_index);
mask |= 1 << dense_index;
if (lit == relevant_literals[dense_index - new_clause.size()]) {
base_index |= 1 << dense_index;
}
}
}
if (skip) continue;
mask >>= new_clause.size();
base_index >>= new_clause.size();
k = k + 1 - new_clause.size();
VLOG(2) << k << " " << std::bitset<8>(mask) << " "
<< std::bitset<8>(base_index);
// TODO(user): we could be faster here if it become needed.
for (int m = 0; m < (1 << n); ++m) {
if ((m & mask) != base_index) continue; // not included.
const int index = m | base_index;
for (int j = k; j <= n; ++j) {
if (index >> j == 0) {
VLOG(2) << "Included in " << j << " "
<< std::bitset<8>((1 << j) | index);
ids[(1 << j) | index] = ids_for_proof[i];
}
}
}
}
// We can prove the others by decreasing k.
std::vector<Literal> tmp_clause;
tmp_clause.assign(new_clause.begin(), new_clause.end());
std::vector<bool> id_need_deletion(num_intermediates, false);
for (int k = n; --k >= 0;) {
for (int m = 0; m < (1 << k); ++m) {
const int index = (1 << k) | m;
if (ids[index] != kNoClauseId) continue; // Already proven.
// Generate the tmp_clause.
tmp_clause.resize(new_clause.size());
for (int i = 0; i < k; ++i) {
tmp_clause.push_back(relevant_literals[i]);
if (((index >> i) & 1) == 0) {
tmp_clause.back() = tmp_clause.back().Negated();
}
}
// Prove it from the two clauses at k + 1.
const int higher1 = index ^ ((0b11) << k);
const int higher2 = index ^ ((0b10) << k);
const ClauseId id1 = ids[higher1];
const ClauseId id2 = ids[higher2];
if (id1 == kNoClauseId || id2 == kNoClauseId) {
VLOG(2) << "missing higher level clauses in the resolution."
<< " index: " << std::bitset<8>(index)
<< " higher1: " << std::bitset<8>(higher1)
<< " higher2: " << std::bitset<8>(higher2);
return false;
}
ids[index] = k == 0 ? new_id : id_generator_->GetNextId();
if (k != 0) {
VLOG(2) << "temporary !! " << ids[index] << " " << tmp_clause;
id_need_deletion[index] = true; // temporary.
}
if (!AddInferredClause(ids[index], tmp_clause, {id1, id2})) {
VLOG(2) << "Failed resolution step";
return false;
}
if (k == 0) {
DCHECK_EQ(new_clause, tmp_clause);
VLOG(2) << "Proven " << new_clause << "!";
}
// Lets delete the ids if they were temporary.
if (id_need_deletion[higher1]) {
tmp_clause.push_back(relevant_literals[k].Negated());
VLOG(2) << "deleting: " << id1 << " " << tmp_clause;
DeleteClause(id1, tmp_clause);
tmp_clause.pop_back();
}
if (id_need_deletion[higher2]) {
tmp_clause.push_back(relevant_literals[k]);
VLOG(2) << "deleting: " << id2 << " " << tmp_clause;
DeleteClause(id2, tmp_clause);
tmp_clause.pop_back();
}
}
}
return true;
}
} // namespace sat
} // namespace operations_research

View File

@@ -33,6 +33,7 @@
#include "ortools/sat/recordio.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/synchronization.h"
#include "ortools/sat/util.h"
namespace operations_research {
namespace sat {
@@ -149,6 +150,19 @@ class LratProofHandler {
absl::Span<const ClauseId> unit_ids,
absl::Span<const LratChecker::RatIds> rat = {});
// This assumes that the 'new_clause' to prove and all the ones needed for the
// proof only touch a small number of variables (<= 6). It will then prove the
// new clause by enumerating all possibilities and producing the relevant
// intermediate LRAT RUP steps.
//
// The last two arguments must have the same size and are in one to one
// correspondence. Note that we might not need all the given clauses in the
// proof.
bool AddAndProveInferredClauseByEnumeration(
ClauseId new_id, absl::Span<const Literal> new_clause,
absl::Span<const ClauseId> ids_for_proof,
const CompactVectorVector<int, Literal>& clauses_for_proof);
// Adds a clause which was inferred by another worker. Returns true if
// successful (the operation can fail if LRAT checks are enabled, and the ID
// is already used by another clause).
@@ -190,6 +204,7 @@ class LratProofHandler {
bool LratError() const;
const int id_;
ClauseIdGenerator* id_generator_;
SharedLratProofStatus* proof_status_;
std::unique_ptr<LratChecker> lrat_checker_;
std::unique_ptr<LratWriter> lrat_writer_;

View File

@@ -0,0 +1,76 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/sat/lrat_proof_handler.h"
#include <memory>
#include <vector>
#include "absl/types/span.h"
#include "gtest/gtest.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/util.h"
namespace operations_research::sat {
namespace {
TEST(AddAndProveInferredClauseByEnumerationTest, XorEquivalence) {
// We assume a = XOR(c,d) and b = XOR(c,d) and we want to prove a => b
const Literal a(+1);
const Literal b(+2);
const Literal c(+3);
const Literal d(+4);
// Encode the two XOR gates.
CompactVectorVector<int, Literal> clauses;
for (const Literal x : {a, b}) {
clauses.Add({c, d, x.Negated()});
clauses.Add({c.Negated(), d, x});
clauses.Add({c, d.Negated(), x});
clauses.Add({c.Negated(), d.Negated(), x.Negated()});
}
// Create the lrat checker.
Model model;
auto* params = model.GetOrCreate<SatParameters>();
params->set_check_lrat_proof(true);
params->set_check_drat_proof(true);
std::unique_ptr<LratProofHandler> lrat =
LratProofHandler::MaybeCreate(&model);
// Lets create ids for all these clauses.
auto* id_generator = model.GetOrCreate<ClauseIdGenerator>();
std::vector<ClauseId> clause_ids;
for (int i = 0; i < clauses.size(); ++i) {
const ClauseId id = id_generator->GetNextId();
clause_ids.push_back(id);
lrat->AddProblemClause(id, clauses[i]);
}
lrat->EndProblemClauses();
// This should be enough to prove equivalence.
{
std::vector<Literal> to_prove = {b.Negated(), a};
EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration(
id_generator->GetNextId(), to_prove, clause_ids, clauses));
}
{
std::vector<Literal> to_prove = {a.Negated(), b};
EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration(
id_generator->GetNextId(), to_prove, clause_ids, clauses));
}
}
} // namespace
} // namespace operations_research::sat

View File

@@ -242,9 +242,17 @@ class AssignmentView {
// The ID of a clause.
DEFINE_STRONG_INT_TYPE(ClauseId, int64_t);
constexpr ClauseId kNoClauseId(0);
// A generator for ClauseIds. Not thread-safe.
class ClauseIdGenerator {
public:
ClauseId GetNextId() { return ClauseId(next_id_++); }
private:
ClauseId next_id_ = ClauseId(1);
};
// Forward declaration.
class SatClause;
class SatPropagator;
@@ -346,12 +354,12 @@ class Trail {
EnqueueHelper(
Literal* trail_ptr, AssignmentInfo* current_info,
AssignmentInfo* info_ptr, VariablesAssignment* assignment,
util_intops::StrongVector<BooleanVariable, ClauseId>* unit_clause_id)
util_intops::StrongVector<BooleanVariable, ClauseId>* clause_ids)
: trail_ptr_(trail_ptr),
current_info_(current_info),
info_ptr_(info_ptr),
bitset_(assignment->GetBitsetView()),
unit_clause_id_(unit_clause_id) {}
clause_ids_(clause_ids) {}
void EnqueueAtLevel(Literal true_literal, int level) {
bitset_.Set(true_literal);
@@ -364,10 +372,10 @@ class Trail {
void EnqueueWithUnitReason(Literal true_literal, ClauseId clause_id) {
DCHECK_NE(clause_id, kNoClauseId);
const BooleanVariable var = true_literal.Variable();
if (var.value() >= unit_clause_id_->size()) {
unit_clause_id_->resize(var.value() + 1, kNoClauseId);
if (var.value() >= clause_ids_->size()) {
clause_ids_->resize(var.value() + 1, kNoClauseId);
}
(*unit_clause_id_)[var] = clause_id;
(*clause_ids_)[var] = clause_id;
bitset_.Set(true_literal);
AssignmentInfo* info = info_ptr_ + true_literal.Variable().value();
@@ -389,12 +397,12 @@ class Trail {
AssignmentInfo* current_info_;
AssignmentInfo* info_ptr_;
Bitset64<LiteralIndex>::View bitset_;
util_intops::StrongVector<BooleanVariable, ClauseId>* unit_clause_id_;
util_intops::StrongVector<BooleanVariable, ClauseId>* clause_ids_;
};
EnqueueHelper GetEnqueueHelper(int propagator_id) {
current_info_.type = propagator_id;
return EnqueueHelper(trail_.data(), &current_info_, info_.data(),
&assignment_, &unit_clause_id_);
&assignment_, &clause_ids_);
}
// Specific Enqueue() for search decisions.
@@ -435,13 +443,7 @@ class Trail {
// Specific Enqueue() version for unit clauses.
void EnqueueWithUnitReason(ClauseId clause_id, Literal true_literal) {
if (clause_id != kNoClauseId) {
const BooleanVariable var = true_literal.Variable();
if (var.value() >= unit_clause_id_.size()) {
unit_clause_id_.resize(var.value() + 1, kNoClauseId);
}
unit_clause_id_[var] = clause_id;
}
MaybeSetClauseId(true_literal, clause_id);
EnqueueAtLevel(true_literal, AssignmentType::kUnitReason, 0);
}
@@ -459,18 +461,21 @@ class Trail {
// Enqueues the given literal using the current content of
// GetEmptyVectorToStoreReason() as the reason. This API is a bit more
// leanient and does not require the literal to be unassigned. If it is
// lenient and does not require the literal to be unassigned. If it is
// already assigned to false, then MutableConflict() will be set appropriately
// and this will return false otherwise this will enqueue the literal and
// returns true.
ABSL_MUST_USE_RESULT bool EnqueueWithStoredReason(Literal true_literal) {
ABSL_MUST_USE_RESULT bool EnqueueWithStoredReason(ClauseId clause_id,
Literal true_literal) {
if (assignment_.LiteralIsTrue(true_literal)) return true;
if (assignment_.LiteralIsFalse(true_literal)) {
*MutableConflict() = reasons_repository_[Index()];
MutableConflict()->push_back(true_literal);
failing_clause_id_ = clause_id;
return false;
}
MaybeSetClauseId(true_literal, clause_id);
Enqueue(true_literal, AssignmentType::kCachedReason);
const BooleanVariable var = true_literal.Variable();
reasons_[var] = reasons_repository_[info_[var].trail_index];
@@ -516,8 +521,17 @@ class Trail {
ClauseId GetUnitClauseId(BooleanVariable var) const {
DCHECK(AssignmentType(var) == AssignmentType::kUnitReason);
DCHECK_EQ(Info(var).level, 0);
if (var.value() >= unit_clause_id_.size()) return kNoClauseId;
return unit_clause_id_[var];
if (var.value() >= clause_ids_.size()) return kNoClauseId;
return clause_ids_[var];
}
// Returns the ID of the clause which is the reason why the given variable was
// enqueued, or kNoClauseId if there is none. The variable must have been
// enqueued with EnqueueWithStoredReason().
ClauseId GetStoredReasonClauseId(BooleanVariable var) const {
DCHECK(AssignmentType(var) == AssignmentType::kCachedReason);
if (var.value() >= clause_ids_.size()) return kNoClauseId;
return clause_ids_[var];
}
// If a variable was propagated with EnqueueWithSameReasonAs(), returns its
@@ -583,6 +597,7 @@ class Trail {
std::vector<Literal>* MutableConflict() {
++conflict_timestamp_;
failing_sat_clause_ = nullptr;
failing_clause_id_ = kNoClauseId;
return &conflict_;
}
@@ -600,9 +615,16 @@ class Trail {
// Specific SatClause interface so we can update the conflict clause activity.
// Note that MutableConflict() automatically sets this to nullptr, so we can
// know whether or not the last conflict was caused by a clause.
void SetFailingSatClause(SatClause* clause) { failing_sat_clause_ = clause; }
void SetFailingSatClause(SatClause* clause) {
failing_sat_clause_ = clause;
failing_clause_id_ = kNoClauseId;
}
SatClause* FailingSatClause() const { return failing_sat_clause_; }
// Returns the LRAT ID of the failing clause. This ID is only set if a
// conflict is detected in EnqueueWithStoredReason().
ClauseId FailingClauseId() const { return failing_clause_id_; }
// Getters.
int NumVariables() const { return trail_.size(); }
int64_t NumberOfEnqueues() const { return num_untrailed_enqueues_ + Index(); }
@@ -664,6 +686,16 @@ class Trail {
private:
ConflictResolutionFunction resolution_;
void MaybeSetClauseId(Literal true_literal, ClauseId clause_id) {
if (clause_id != kNoClauseId) {
const BooleanVariable var = true_literal.Variable();
if (var.value() >= clause_ids_.size()) {
clause_ids_.resize(var.value() + 1, kNoClauseId);
}
clause_ids_[var] = clause_id;
}
}
// Finds all literals between the current trail index and the given one
// assigned at the current level or lower, and re-enqueues them with the same
// reason.
@@ -678,9 +710,11 @@ class Trail {
int64_t conflict_timestamp_ = 0;
std::vector<Literal> conflict_;
util_intops::StrongVector<BooleanVariable, AssignmentInfo> info_;
// The ID of unit clauses (literals enqueued at level 0 with a kUnitReason).
util_intops::StrongVector<BooleanVariable, ClauseId> unit_clause_id_;
// The ID of unit clauses (literals enqueued at level 0 with a kUnitReason)
// and of reason clauses for literals enqueued with a stored reason.
util_intops::StrongVector<BooleanVariable, ClauseId> clause_ids_;
SatClause* failing_sat_clause_;
ClauseId failing_clause_id_;
// Data used by EnqueueWithSameReasonAs().
util_intops::StrongVector<BooleanVariable, BooleanVariable>

View File

@@ -14,6 +14,7 @@
#include "ortools/sat/sat_inprocessing.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <deque>
@@ -37,6 +38,7 @@
#include "ortools/base/timer.h"
#include "ortools/graph/connected_components.h"
#include "ortools/sat/clause.h"
#include "ortools/sat/gate_utils.h"
#include "ortools/sat/linear_programming_constraint.h"
#include "ortools/sat/lrat_proof_handler.h"
#include "ortools/sat/probing.h"
@@ -1916,14 +1918,46 @@ GateCongruenceClosure::~GateCongruenceClosure() {
});
}
template <int arity>
void GateCongruenceClosure::AddToTruthTable(
absl::Span<const Literal> clause,
absl::flat_hash_map<std::array<BooleanVariable, arity>, SmallBitset>&
data) {
CHECK_EQ(clause.size(), arity);
std::array<BooleanVariable, arity> key;
SmallBitset bitmask;
FillKeyAndBitmask(clause, absl::MakeSpan(key), bitmask);
for (const BooleanVariable var : key) {
CHECK(!implication_graph_->IsRemoved(Literal(var, true)));
}
auto [it, inserted] = data.insert({key, bitmask});
if (!inserted) {
const SmallBitset old = it->second;
it->second &= bitmask; // Remove one value.
if (old != it->second) {
// TODO(user): keep id for proof.
}
}
}
// Note that this is the "hot" part of the algo, once we have the and gates,
// the congruence closure should be quite fast.
void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) {
void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
PresolveTimer& timer) {
truth_table3_.clear();
truth_table4_.clear();
std::vector<Literal> candidates;
for (SatClause* clause : clause_manager_->AllClausesInCreationOrder()) {
if (timer.WorkLimitIsReached()) break;
if (clause->size() == 0) continue;
if (clause->size() == 3) {
AddToTruthTable<3>(clause->AsSpan(), truth_table3_);
} else if (clause->size() == 4) {
AddToTruthTable<4>(clause->AsSpan(), truth_table4_);
}
// Used for an optimization below.
int min_num_implications = std::numeric_limits<int>::max();
Literal lit_with_less_implications;
@@ -2030,6 +2064,8 @@ void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) {
// Add the detected gate (its inputs are the negation of each clause
// literal other than the target).
gates_target_.push_back(target.Index());
gates_type_.push_back(kAndGateType);
const int index = gates_inputs_.Add({});
for (const Literal l : clause->AsSpan()) {
if (l == target) continue;
@@ -2048,6 +2084,60 @@ void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) {
// a single base clause of size n will correspond to n and_gates !
}
}
timer.AddCounter("and_gates", gates_inputs_.size());
}
template <int arity>
void GateCongruenceClosure::ExtractShortGates(
PresolveTimer& timer,
const absl::flat_hash_map<std::array<BooleanVariable, arity>, SmallBitset>&
data) {
// For a table on n variables, we look for function x = f(n - 1) variable.
const int num_bits = arity - 1;
// TODO(user): This is non-deterministic order. We need to fix that or
// initially sort the queue of gates to process.
int num_functions = 0;
for (const auto [key, truth_table] : data) {
for (int i = 0; i < arity; ++i) {
if (!IsFunction<arity>(i, truth_table)) continue;
++num_functions;
gates_target_.push_back(Literal(key[i], true));
gates_inputs_.Add({});
for (int j = 0; j < arity; ++j) {
if (i != j) {
gates_inputs_.AppendToLastVector(Literal(key[j], true));
}
}
// Generate the function truth table as a type.
// We will canonicalize it further in the main loop.
unsigned int type = 0;
for (int p = 0; p < (1 << num_bits); ++p) {
// Expand from (arity - 1) bits to (arity) bits.
const int bigger_p = AddHoleAtPosition(i, p);
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
// target is 1 at this position.
type |= 1 << p;
DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function.
} else {
// Note that if there is no feasible assignment for a given p, first
// we could have learned a smaller clause, but also we don't really
// care what is the value of the target at that point p, so we use
// zero.
}
}
gates_type_.push_back(type);
gates_clause_.push_back(nullptr);
}
}
timer.AddCounter(absl::StrCat("table", arity), data.size());
timer.AddCounter(absl::StrCat("fn", num_bits), num_functions);
}
namespace {
@@ -2290,17 +2380,27 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
gates_target_.clear();
gates_inputs_.clear();
gates_type_.clear();
gates_clause_.clear();
// TODO(user): Extract more gates, and encode store their type.
ExtractAndGates(timer);
timer.AddCounter("and_gates", gates_inputs_.size());
ExtractAndGatesAndFillShortTruthTables(timer);
// TODO(user): We currently do not support this with lrat. Fix.
if (lrat_proof_handler_ == nullptr) {
ExtractShortGates<3>(timer, truth_table3_);
ExtractShortGates<4>(timer, truth_table4_);
}
// All vector have the same size.
// Except gates_clause_ which is only filled if we need proof.
CHECK_EQ(gates_target_.size(), gates_type_.size());
CHECK_EQ(gates_target_.size(), gates_inputs_.size());
// If two gates have the same type and the same inputs, their targets are
// equivalent. We use an hash set to detect that the inputs are the same.
absl::flat_hash_set<int, GateHash, GateEq> gate_set(
/*capacity=*/gates_inputs_.size(), GateHash(&gates_inputs_),
GateEq(&gates_inputs_));
/*capacity=*/gates_inputs_.size(), GateHash(&gates_type_, &gates_inputs_),
GateEq(&gates_type_, &gates_inputs_));
// Used to find representatives as we detect equivalent literal.
DenseConnectedComponentsFinder union_find;
@@ -2330,6 +2430,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
int num_units = 0;
int num_processed = 0;
int num_equivalences = 0;
int arity1_equivalences = 0;
while (!queue.empty()) {
++num_processed;
const int id = queue.back();
@@ -2368,13 +2469,13 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
if (is_equivalent) {
CHECK_NE(id, other_id);
CHECK_GE(other_id, 0);
CHECK_EQ(gates_type_[id], gates_type_[other_id]);
CHECK_EQ(absl::Span<const LiteralIndex>(gates_inputs_[id]),
absl::Span<const LiteralIndex>(gates_inputs_[other_id]));
// We detected a <=> b (or, equivalently, rep(a) <=> rep(b)).
const LiteralIndex a = gates_target_[id];
const LiteralIndex b = gates_target_[other_id];
input_literals_to_gate.RemoveFromFutureOutput(id);
if (lrat_proof_handler_ != nullptr) {
lrat_helper.ShortenEquivalencesWithRepresentative(Literal(a));
@@ -2382,6 +2483,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
}
const LiteralIndex rep_a(union_find.FindRoot(a.value()));
const LiteralIndex rep_b(union_find.FindRoot(b.value()));
if (rep_a != rep_b) {
++num_equivalences;
const Literal rep_lit_a(rep_a);
@@ -2401,45 +2503,58 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
return false;
}
union_find.AddEdge(rep_a.value(), rep_b.value());
const LiteralIndex rep(union_find.FindRoot(rep_b.value()));
const LiteralIndex to_merge = rep == rep_a ? rep_b : rep_a;
input_literals_to_gate.MergeInto(to_merge, rep);
if (lrat_proof_handler_ != nullptr) {
if (rep == rep_a) {
lrat_helper.AddGateEquivalenceClauses(
rep_lit_b, rep_b_implies_rep_a, rep_a_implies_rep_b);
} else {
lrat_helper.AddGateEquivalenceClauses(
rep_lit_a, rep_a_implies_rep_b, rep_b_implies_rep_a);
}
}
for (const bool negate : {false, true}) {
const LiteralIndex x =
negate ? Literal(rep_a).NegatedIndex() : rep_a;
const LiteralIndex y =
negate ? Literal(rep_b).NegatedIndex() : rep_b;
// Re-add to the queue all gates with touched inputs.
//
// TODO(user): I think we could only add the gates of "to_merge"
// before we merge. This part of the code is quite quick in any case.
for (const int gate_id : input_literals_to_gate[rep]) {
if (in_queue[gate_id]) continue;
queue.push_back(gate_id);
in_queue[gate_id] = true;
union_find.AddEdge(x.value(), y.value());
const LiteralIndex rep(union_find.FindRoot(y.value()));
const LiteralIndex to_merge = rep == x ? y : x;
input_literals_to_gate.MergeInto(to_merge, rep);
if (lrat_proof_handler_ != nullptr) {
if (rep == x) {
lrat_helper.AddGateEquivalenceClauses(
Literal(y),
negate ? rep_a_implies_rep_b : rep_b_implies_rep_a,
negate ? rep_b_implies_rep_a : rep_a_implies_rep_b);
} else {
lrat_helper.AddGateEquivalenceClauses(
Literal(x),
negate ? rep_b_implies_rep_a : rep_a_implies_rep_b,
negate ? rep_a_implies_rep_b : rep_b_implies_rep_a);
}
}
// Re-add to the queue all gates with touched inputs.
//
// TODO(user): I think we could only add the gates of "to_merge"
// before we merge. This part of the code is quite quick in any
// case.
for (const int gate_id : input_literals_to_gate[rep]) {
if (in_queue[gate_id]) continue;
queue.push_back(gate_id);
in_queue[gate_id] = true;
}
}
}
break;
}
// Canonicalize.
// Canonicalize on pass zero, otherwise loop.
// Note that even if nothing change, we want to run canonicalize at
// least once on the small "truth table" gates.
//
// Note that sorting works for and_gate and any gate that does not depend
// on the order of its inputs. But if we add more fancy functions, we will
// need to be careful.
//
// TODO(user): Because we fix literal, we should also deal with fixed
// literals here. Right now we defer this to a later step, but it might
// reduce the cascading effect of finding more equivalences.
if (pass == 0) {
marked_.ResetAllToFalse();
if (pass > 0) continue;
if (gates_type_[id] == kAndGateType) {
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
marked_.ResetAllToFalse();
int new_size = 0;
bool is_unit = false;
for (const LiteralIndex l : inputs) {
@@ -2454,13 +2569,16 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
if (marked_[Literal(rep).Negated()]) {
is_unit = true;
const Literal to_fix = Literal(gates_target_[id]).Negated();
absl::InlinedVector<ClauseId, 4> clause_ids;
if (lrat_proof_handler_ != nullptr) {
lrat_helper.AppendFixAndGateTargetClauses(id, Literal(rep),
clause_ids);
}
if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) {
return false;
if (!assignment_.LiteralIsTrue(to_fix)) {
absl::InlinedVector<ClauseId, 4> clause_ids;
if (lrat_proof_handler_ != nullptr) {
lrat_helper.AppendFixAndGateTargetClauses(id, Literal(rep),
clause_ids);
}
if (!clause_manager_->InprocessingFixLiteral(to_fix,
clause_ids)) {
return false;
}
}
break;
}
@@ -2480,6 +2598,50 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
CHECK_GT(new_size, 0);
std::sort(inputs.begin(), inputs.begin() + new_size);
gates_inputs_.Shrink(id, new_size);
// Lets convert to the short "type" if we can. The truth table is simply
// a 1 on the last position (where all inputs are ones). We fall back to
// the case below to canonicalize further.
if (new_size > 4 || lrat_proof_handler_ != nullptr) continue;
gates_type_[id] = 1 << ((1 << new_size) - 1);
}
// Generic "short" gates.
// We just take the representative and re-canonicalize.
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
CHECK_GE(gates_type_[id], 0);
CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
for (LiteralIndex& lit_ref : inputs) {
lit_ref = LiteralIndex(union_find.FindRoot(lit_ref.value()));
}
const int new_size = CanonicalizeFunctionTruthTable(
gates_target_[id], inputs, gates_type_[id]);
if (new_size < inputs.size()) {
gates_inputs_.Shrink(id, new_size);
}
if (new_size == 1) {
// We have a function of size 1! this is an equivalence.
//
// TODO(user): deal with it.
++arity1_equivalences;
input_literals_to_gate.RemoveFromFutureOutput(id);
break;
} else if (new_size == 0) {
// We have a fixed function ! just fix the literal.
CHECK(Literal(gates_target_[id]).IsPositive());
const Literal to_fix{Literal(gates_target_[id]).Variable(),
(gates_type_[id] & 1) == 1};
if (!assignment_.LiteralIsTrue(to_fix)) {
absl::InlinedVector<ClauseId, 4> clause_ids;
if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) {
return false;
}
++num_units;
}
input_literals_to_gate.RemoveFromFutureOutput(id);
break;
}
}
}
@@ -2490,6 +2652,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
total_equivalences_ += num_equivalences;
total_num_units_ += num_units;
timer.AddCounter("arity1_equivalences", arity1_equivalences);
timer.AddCounter("units", num_units);
timer.AddCounter("processed", num_processed);
timer.AddCounter("equivalences", num_equivalences);

View File

@@ -30,6 +30,7 @@
#include "absl/types/span.h"
#include "ortools/base/strong_vector.h"
#include "ortools/sat/clause.h"
#include "ortools/sat/gate_utils.h"
#include "ortools/sat/linear_programming_constraint.h"
#include "ortools/sat/lrat_proof_handler.h"
#include "ortools/sat/model.h"
@@ -439,7 +440,8 @@ class BoundedVariableElimination {
class GateCongruenceClosure {
public:
explicit GateCongruenceClosure(Model* model)
: sat_solver_(model->GetOrCreate<SatSolver>()),
: assignment_(model->GetOrCreate<Trail>()->Assignment()),
sat_solver_(model->GetOrCreate<SatSolver>()),
implication_graph_(model->GetOrCreate<BinaryImplicationGraph>()),
clause_manager_(model->GetOrCreate<ClauseManager>()),
clause_id_generator_(model->GetOrCreate<ClauseIdGenerator>()),
@@ -454,23 +456,29 @@ class GateCongruenceClosure {
private:
struct GateHash {
explicit GateHash(CompactVectorVector<int, LiteralIndex>* g)
: gates_inputs(g) {}
explicit GateHash(std::vector<int>* t,
CompactVectorVector<int, LiteralIndex>* g)
: gates_type(t), gates_inputs(g) {}
std::size_t operator()(int gate_id) const {
return absl::HashOf((*gates_inputs)[gate_id]);
return absl::HashOf((*gates_type)[gate_id], (*gates_inputs)[gate_id]);
}
CompactVectorVector<int, LiteralIndex>* gates_inputs;
const std::vector<int>* gates_type;
const CompactVectorVector<int, LiteralIndex>* gates_inputs;
};
struct GateEq {
explicit GateEq(CompactVectorVector<int, LiteralIndex>* g)
: gates_inputs(g) {}
explicit GateEq(std::vector<int>* t,
CompactVectorVector<int, LiteralIndex>* g)
: gates_type(t), gates_inputs(g) {}
bool operator()(int gate_a, int gate_b) const {
if (gate_a == gate_b) return true;
// We use absl::span<> comparison.
return (*gates_inputs)[gate_a] == (*gates_inputs)[gate_b];
return ((*gates_type)[gate_a] == (*gates_type)[gate_b]) &&
((*gates_inputs)[gate_a] == (*gates_inputs)[gate_b]);
}
CompactVectorVector<int, LiteralIndex>* gates_inputs;
const std::vector<int>* gates_type;
const CompactVectorVector<int, LiteralIndex>* gates_inputs;
};
// Recovers "target_literal = and(literals)" from the model.
@@ -483,8 +491,20 @@ class GateCongruenceClosure {
// - for all i, target_literal => literal_i (direct binary implication)
// - all literal at true => target_literal, this is a clause:
// (not(literal[i]) for all i, target_literal).
void ExtractAndGates(PresolveTimer& timer);
void ExtractAndGatesAndFillShortTruthTables(PresolveTimer& timer);
// From possible assignment of "arity" given variables, extract functions.
template <int arity>
void ExtractShortGates(
PresolveTimer& timer,
const absl::flat_hash_map<std::array<BooleanVariable, arity>,
SmallBitset>& data);
template <int arity>
void AddToTruthTable(absl::Span<const Literal> clause,
absl::flat_hash_map<std::array<BooleanVariable, arity>,
SmallBitset>& data);
const VariablesAssignment& assignment_;
SatSolver* sat_solver_;
BinaryImplicationGraph* implication_graph_;
ClauseManager* clause_manager_;
@@ -501,17 +521,33 @@ class GateCongruenceClosure {
// A Boolean gates correspond to target = f(inputs).
//
// Note that the inputs are canonicalized. For and_gates, they are sorted,
// since the gate function does not depend on the order.
// since the gate function does not depend on the order. The type of an
// and_gates is kAndGateType.
//
// TODO(user): for now we have a single gate type. We can easily support more
// by creating an extra std::vector<GateType> and adding that to our
// GateHash/GateEq hash_set.
// Otherwise, we support generic 2 and 3 inputs gates where the type is the
// truth table. i.e. target = type[sum value_of_inputs[i] * 2^i]. For such
// gate, the target and inputs will always be canonicalized to positive and
// sorted literal. We just update the truth table accordingly.
static constexpr int kAndGateType = -1;
std::vector<LiteralIndex> gates_target_;
std::vector<int> gates_type_;
CompactVectorVector<int, LiteralIndex> gates_inputs_;
// For each gate, "the" corresponding clause. For a gate a = and(x,y,...) this
// is the clause "x and y and ... => a". Only used for LRAT.
std::vector<const SatClause*> gates_clause_;
// Map (Xi) (sorted) to a bitmask corresponding to the allowed values.
// We loop over all short clauses to fill this.
//
// TODO(user): Shorter clauses impact larger truth table too and we can
// combine two size 3 to construct a size 4 (needed for ITE-gate).
// not ideal.
absl::flat_hash_map<std::array<BooleanVariable, 3>, SmallBitset>
truth_table3_;
absl::flat_hash_map<std::array<BooleanVariable, 4>, SmallBitset>
truth_table4_;
// For stats.
double total_dtime_ = 0.0;
double total_wtime_ = 0.0;

View File

@@ -127,10 +127,11 @@ message SatParameters {
enum BinaryMinizationAlgorithm {
reserved 2, 3, 4;
NO_BINARY_MINIMIZATION = 0;
BINARY_MINIMIZATION_FIRST = 1;
BINARY_MINIMIZATION_FROM_UIP = 1;
BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS = 5;
}
optional BinaryMinizationAlgorithm binary_minimization_algorithm = 34
[default = BINARY_MINIMIZATION_FIRST];
[default = BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS];
// At a really low cost, during the 1-UIP conflict computation, it is easy to
// detect if some of the involved reasons are subsumed by the current

View File

@@ -974,10 +974,16 @@ void SatSolver::ProcessCurrentConflict(
switch (parameters_->binary_minimization_algorithm()) {
case SatParameters::NO_BINARY_MINIMIZATION:
break;
case SatParameters::BINARY_MINIMIZATION_FIRST:
case SatParameters::BINARY_MINIMIZATION_FROM_UIP:
binary_implication_graph_->MinimizeConflictFirst(
*trail_, &learned_conflict_, &is_marked_,
clause_ids_for_minimization);
clause_ids_for_minimization, /*also_use_decisions=*/false);
break;
case SatParameters::BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS:
binary_implication_graph_->MinimizeConflictFirst(
*trail_, &learned_conflict_, &is_marked_,
clause_ids_for_minimization, /*also_use_decisions=*/true);
break;
}
DCHECK(IsConflictValid(learned_conflict_));
}
@@ -1060,7 +1066,7 @@ void SatSolver::ProcessCurrentConflict(
for (const Literal l : clause) {
if (Assignment().LiteralIsFalse(l)) ++num_false;
}
if (num_false == clause.size()) {
if (num_false == clause.size() || clause.size() == 1) {
int max_level = 0;
for (const Literal l : clause) {
const int level = AssignmentLevel(l.Variable());
@@ -1366,6 +1372,8 @@ void SatSolver::AppendLratProofForFailingClause(
const SatClause* failing_sat_clause = trail_->FailingSatClause();
if (failing_sat_clause != nullptr) {
failing_clause_id = clauses_propagator_->GetClauseId(failing_sat_clause);
} else if (trail_->FailingClauseId() != kNoClauseId) {
failing_clause_id = trail_->FailingClauseId();
} else {
absl::Span<const Literal> failing_clause = trail_->FailingClause();
if (failing_clause.size() == 2) {
@@ -2342,6 +2350,7 @@ std::string SatSolver::RunningStatisticsString() const {
void SatSolver::ProcessNewlyFixedVariables() {
SCOPED_TIME_STAT(&stats_);
DCHECK_EQ(CurrentDecisionLevel(), 0);
if (num_processed_fixed_variables_ == trail_->Index()) return;
int num_detached_clauses = 0;
int num_binary = 0;

View File

@@ -91,17 +91,19 @@ void SharedStatTables::AddSearchStat(absl::string_view name, Model* model) {
void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) {
absl::MutexLock mutex_lock(mutex_);
auto* sat_solver = model->GetOrCreate<SatSolver>();
auto* binary = model->GetOrCreate<BinaryImplicationGraph>();
SatSolver::Counters counters = sat_solver->counters();
if (clauses_table_.empty()) {
clauses_table_.push_back({"SAT stats", "ClassicMinim", "LitRemoved",
"LitLearned", "LitForgotten", "Subsumed",
"MClauses", "MDecisions", "MLitTrue", "MSubsumed",
"MLitRemoved", "MReused"});
"LitRemovedBinary", "LitLearned", "LitForgotten",
"Subsumed", "MClauses", "MDecisions", "MLitTrue",
"MSubsumed", "MLitRemoved", "MReused"});
}
clauses_table_.push_back(
{FormatName(name), FormatCounter(counters.num_minimizations),
FormatCounter(counters.num_literals_removed),
FormatCounter(binary->num_literals_removed()),
FormatCounter(counters.num_literals_learned),
FormatCounter(counters.num_literals_forgotten),
FormatCounter(counters.num_subsumed_clauses),
@@ -117,7 +119,6 @@ void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) {
bool_var_table_.push_back(
{"Boolean variables", "Fixed", "Equiv", "Total", "Left", "Binary"});
}
auto* binary = model->GetOrCreate<BinaryImplicationGraph>();
const int64_t num_fixed = sat_solver->NumFixedVariables();
const int64_t num_equiv = binary->num_redundant_literals() / 2;
const int64_t num_bools = sat_solver->NumVariables();

View File

@@ -607,33 +607,6 @@ void MaxBoundedSubsetSum::AddChoicesInternal(absl::Span<const int64_t> values) {
}
}
int64_t MaxBoundedSubsetSum::MaxIfAdded(int64_t candidate) const {
if (candidate > bound_ || current_max_ == bound_) return current_max_;
int64_t current_max = current_max_;
// Mode 1: vector of all possible sums (with duplicates).
if (!sums_.empty()) {
for (const int64_t v : sums_) {
if (v + candidate > bound_) continue;
if (v + candidate > current_max) {
current_max = v + candidate;
if (current_max == bound_) return current_max;
}
}
return current_max;
}
// Mode 2: bitset of all possible sums.
if (!expanded_sums_.empty()) {
const int64_t min_useful = std::max<int64_t>(0, current_max_ - candidate);
const int64_t max_useful = bound_ - candidate;
for (int64_t v = max_useful; v >= min_useful; --v) {
if (expanded_sums_[v]) return v + candidate;
}
}
return current_max_;
}
BasicKnapsackSolver::Result BasicKnapsackSolver::Solve(
absl::Span<const Domain> domains, absl::Span<const int64_t> coeffs,
absl::Span<const int64_t> costs, const Domain& rhs) {

View File

@@ -550,9 +550,6 @@ class MaxBoundedSubsetSum {
// We look for the maximum sum <= bound.
void Reset(int64_t bound);
// Returns the updated max if value was added to the subset-sum.
int64_t MaxIfAdded(int64_t candidate) const;
// Add a value to the base set for which subset sums will be taken.
void Add(int64_t value);

View File

@@ -737,21 +737,6 @@ TEST(MaxBoundedSubsetSumTest, SimpleMultiChoice) {
EXPECT_EQ(bounded_subset_sum.CurrentMax(), 31);
}
TEST(MaxBoundedSubsetSumTest, CheckMaxIfAdded) {
MaxBoundedSubsetSum bounded_subset_sum(34);
bounded_subset_sum.Add(10);
bounded_subset_sum.Add(10);
bounded_subset_sum.Add(10);
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(12), 32);
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(15), 30);
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(34), 34);
for (int i = 0; i < 100; ++i) {
bounded_subset_sum.Add(18);
}
EXPECT_EQ(bounded_subset_sum.CurrentMax(), 30);
EXPECT_EQ(bounded_subset_sum.MaxIfAdded(5), 33);
}
static void BM_bounded_subset_sum(benchmark::State& state) {
random_engine_t random_;
const int num_items = state.range(0);

View File

@@ -635,18 +635,21 @@ const std::vector<Literal>& SharedTreeWorker::DecisionReason(int level) {
}
bool SharedTreeWorker::AddDecisionImplication(Literal lit, int level) {
CHECK_GT(level, 0);
CHECK_NE(lit.Index(), kNoLiteralIndex);
CHECK(!sat_solver_->Assignment().LiteralIsTrue(lit));
absl::Span<const Literal> reason = DecisionReason(level);
if (sat_solver_->Assignment().LiteralIsFalse(lit)) {
VLOG(2) << "Closing subtree via impl at " << level + 1
<< " assigned=" << assigned_tree_.MaxLevel();
integer_trail_->ReportConflict(DecisionReason(level), {});
trail_->MutableConflict()->assign(reason.begin(), reason.end());
manager_->CloseTree(assigned_tree_, level);
assigned_tree_literals_.clear();
return false;
}
VLOG(2) << "Learned shared clause";
return integer_trail_->EnqueueLiteral(lit, DecisionReason(level), {});
trail_->GetEmptyVectorToStoreReason()->assign(reason.begin(), reason.end());
return trail_->EnqueueWithStoredReason(kNoClauseId, lit);
}
bool SharedTreeWorker::AddImplications() {