diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index c0721320e8..d189b11cd9 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -298,6 +298,29 @@ cc_library( ], ) +cc_library( + name = "gate_utils", + hdrs = ["gate_utils.h"], + deps = [ + ":sat_base", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/types:span", + ], +) + +cc_test( + name = "gate_utils_test", + srcs = ["gate_utils_test.cc"], + deps = [ + ":gate_utils", + ":sat_base", + "//ortools/base:gmock_main", + "//ortools/base:logging", + "@abseil-cpp//absl/random", + "@abseil-cpp//absl/types:span", + ], +) + cc_proto_library( name = "cp_model_cc_proto", visibility = ["//visibility:public"], @@ -1585,6 +1608,7 @@ cc_library( hdrs = ["sat_inprocessing.h"], deps = [ ":clause", + ":gate_utils", ":linear_programming_constraint", ":lrat_proof_handler", ":model", @@ -1939,7 +1963,9 @@ cc_library( hdrs = ["integer.h"], visibility = ["//visibility:public"], deps = [ + ":clause", ":integer_base", + ":lrat_proof_handler", ":model", ":sat_base", ":sat_parameters_cc_proto", @@ -4279,6 +4305,7 @@ cc_library( ":recordio", ":sat_base", ":synchronization", + ":util", "//ortools/base:file", "//ortools/base:intops", "//ortools/base:timer", @@ -4291,6 +4318,19 @@ cc_library( ], ) +cc_test( + name = "lrat_proof_handler_test", + srcs = ["lrat_proof_handler_test.cc"], + deps = [ + ":lrat_proof_handler", + ":model", + ":sat_base", + ":util", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/types:span", + ], +) + cc_test( name = "lrat_checker_test", srcs = ["lrat_checker_test.cc"], diff --git a/ortools/sat/all_different.cc b/ortools/sat/all_different.cc index aab63a6551..2587cc771e 100644 --- a/ortools/sat/all_different.cc +++ b/ortools/sat/all_different.cc @@ -404,7 +404,7 @@ bool AllDifferentConstraint::Propagate() { } } - return trail_->EnqueueWithStoredReason(x_lit.Negated()); + return trail_->EnqueueWithStoredReason(kNoClauseId, x_lit.Negated()); } } } diff --git a/ortools/sat/circuit.cc b/ortools/sat/circuit.cc index f18867cdc9..e81549e5ac 100644 --- a/ortools/sat/circuit.cc +++ b/ortools/sat/circuit.cc @@ -308,7 +308,7 @@ bool CircuitPropagator::Propagate() { std::vector* reason = trail_.GetEmptyVectorToStoreReason(); FillReasonForPath(start_node, reason); enforcement_helper_.AddEnforcementReason(enforcement_id_, reason); - if (!trail_.EnqueueWithStoredReason(literal.Negated())) { + if (!trail_.EnqueueWithStoredReason(kNoClauseId, literal.Negated())) { return false; } } @@ -357,7 +357,8 @@ bool CircuitPropagator::Propagate() { if (extra_reason != kFalseLiteralIndex) { reason->push_back(Literal(extra_reason)); } - const bool ok = trail_.EnqueueWithStoredReason(literal.Negated()); + const bool ok = + trail_.EnqueueWithStoredReason(kNoClauseId, literal.Negated()); if (!ok) return false; continue; } @@ -398,7 +399,7 @@ bool CircuitPropagator::Propagate() { std::vector* reason = trail_.GetEmptyVectorToStoreReason(); FillReasonForPath(start_node, reason); enforcement_helper_.AddEnforcementReason(enforcement_id_, reason); - const bool ok = trail_.EnqueueWithStoredReason(literal); + const bool ok = trail_.EnqueueWithStoredReason(kNoClauseId, literal); if (!ok) return false; } else { trail_.EnqueueWithSameReasonAs(literal, variable_with_same_reason); @@ -669,8 +670,8 @@ bool CircuitCoveringPropagator::Propagate() { !trail_->Assignment().LiteralIsFalse(graph_[end][start])) { auto* reason = trail_->GetEmptyVectorToStoreReason(); FillFixedPathInReason(start, end, reason); - const bool ok = - trail_->EnqueueWithStoredReason(graph_[end][start].Negated()); + const bool ok = trail_->EnqueueWithStoredReason( + kNoClauseId, graph_[end][start].Negated()); if (!ok) return false; } } diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index 7e27074d92..d5c13e789e 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -477,20 +477,7 @@ bool ClauseManager::InprocessingAddUnitClause(ClauseId unit_clause_id, bool ClauseManager::InprocessingFixLiteral( Literal true_literal, absl::Span clause_ids) { - DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); - if (trail_->Assignment().LiteralIsTrue(true_literal)) return true; - - ClauseId clause_id = kNoClauseId; - if (lrat_proof_handler_ != nullptr) { - clause_id = clause_id_generator_->GetNextId(); - lrat_proof_handler_->AddInferredClause(clause_id, {true_literal}, - clause_ids); - } - trail_->EnqueueWithUnitReason(clause_id, true_literal); - - // Even when all clauses are detached, we can propagate the implication - // graph and we do that right away. - return implication_graph_->Propagate(trail_); + return implication_graph_->FixLiteral(true_literal, clause_ids); } void ClauseManager::ChangeLbdIfBetter(SatClause* clause, int new_lbd) { @@ -943,21 +930,17 @@ bool BinaryImplicationGraph::AddBinaryClauseInternal( ClauseId id, Literal a, Literal b, bool change_reason, bool delete_non_representative_id) { SCOPED_TIME_STAT(&stats_); + DCHECK_GE(a.Index(), 0); + DCHECK_GE(b.Index(), 0); // Tricky: If this is the first clause, the propagator will be added and // assumed to be in a "propagated" state. This makes sure this is the case. if (no_constraint_ever_added_) propagation_trail_index_ = trail_->Index(); no_constraint_ever_added_ = false; - Literal rep_a = a; - Literal rep_b = b; ClauseId rep_id = kNoClauseId; - if (is_redundant_[a.Index()]) { - rep_a = Literal(representative_of_[a.Index()]); - } - if (is_redundant_[b.Index()]) { - rep_b = Literal(representative_of_[b.Index()]); - } + const Literal rep_a = RepresentativeOf(a); + const Literal rep_b = RepresentativeOf(b); if (rep_a == rep_b.Negated()) return true; if (lrat_proof_handler_ != nullptr) { @@ -1361,15 +1344,50 @@ void BinaryImplicationGraph::Reimply(Trail* trail, int old_trail_index) { // can take advantage of that. void BinaryImplicationGraph::MinimizeConflictFirst( const Trail& trail, std::vector* conflict, - SparseBitset* marked, std::vector* clause_ids) { + SparseBitset* marked, std::vector* clause_ids, + bool also_use_decisions) { SCOPED_TIME_STAT(&stats_); DCHECK(!conflict->empty()); is_marked_.ClearAndResize(LiteralIndex(implications_and_amos_.size())); + + tmp_to_keep_.clear(); + tmp_to_keep_.push_back(conflict->front().Negated()); if (lrat_proof_handler_ != nullptr) { MarkDescendants(conflict->front().Negated()); } else { MarkDescendants(conflict->front().Negated()); } + + // Because the decision cannot be removed from the conflict, we know they will + // stay, so it is okay to see what they propagate in the binary implication + // graph. Technically we could do that for any first literal of a decision + // level. Improve? + std::vector> decisions; + if (also_use_decisions) { + for (int i = 1; i < conflict->size(); ++i) { + const auto& info = trail_->Info((*conflict)[i].Variable()); + if (info.type == AssignmentType::kSearchDecision) { + decisions.push_back({(*conflict)[i].Negated(), info.level}); + } + } + absl::c_stable_sort(decisions, [](const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); + } + + // Keep marking everything propagated by the decisions, and make sure we + // don't remove the one from which we called MarkDescendants(). + for (const auto [literal, unused_level] : decisions) { + if (is_marked_[literal]) continue; + tmp_to_keep_.push_back(literal); + if (lrat_proof_handler_ != nullptr) { + MarkDescendants(literal); + } else { + MarkDescendants(literal); + } + } + for (const LiteralIndex i : is_marked_.PositionsSetAtLeastOnce()) { // TODO(user): if this is false, then we actually have a conflict of size 2. // This can only happen if the binary clause was not propagated properly @@ -1379,6 +1397,9 @@ void BinaryImplicationGraph::MinimizeConflictFirst( marked->Set(Literal(i).Variable()); } } + + // Remove all marked literals from the conflict. + for (const Literal l : tmp_to_keep_) is_marked_.Clear(l); if (lrat_proof_handler_ != nullptr) { RemoveRedundantLiterals(conflict, clause_ids); } else { diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index 0324b54c76..3c611c8abb 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -46,15 +46,6 @@ namespace operations_research { namespace sat { -// A generator for ClauseIds. Not thread-safe. -class ClauseIdGenerator { - public: - ClauseId GetNextId() { return ClauseId(next_id_++); } - - private: - ClauseId next_id_ = ClauseId(1); -}; - // This is how the SatSolver stores a clause. A clause is just a disjunction of // literals. In many places, we just use vector to encode one. But in // the critical propagation code, we use this class to remove one memory @@ -709,7 +700,8 @@ class BinaryImplicationGraph : public SatPropagator { // details about the different algorithms. void MinimizeConflictFirst(const Trail& trail, std::vector* c, SparseBitset* marked, - std::vector* clause_ids); + std::vector* clause_ids, + bool also_use_decisions); // Appends the IDs of the unit and binary clauses that imply the given // literals to `clause_ids`. Either `MinimizeConflictFirst` or @@ -985,6 +977,12 @@ class BinaryImplicationGraph : public SatPropagator { return implications_and_amos_[lit].offsets(); } + // Simple wrapper to not forget to output newly fixed variable to the DRAT + // or LRAT proof (with clause_ids as proof) if needed. This will propagate + // right away the implications. + bool FixLiteral(Literal true_literal, + absl::Span clause_ids = {}); + private: friend class LratEquivalenceHelper; @@ -1005,12 +1003,6 @@ class BinaryImplicationGraph : public SatPropagator { clause_id_[{a, b}] = id; } - // Simple wrapper to not forget to output newly fixed variable to the DRAT - // or LRAT proof (with clause_ids as proof) if needed. This will propagate - // right away the implications. - bool FixLiteral(Literal true_literal, - absl::Span clause_ids = {}); - // Removes any literal whose negation is marked (except the first one). If // `fill_clause_ids` is true, fills the LRAT proof for this change in // `clause_ids` (this requires an LRAT proof handler to be set, and @@ -1117,12 +1109,14 @@ class BinaryImplicationGraph : public SatPropagator { int64_t num_redundant_literals_ = 0; // Bitset used by MinimizeClause(). + // // TODO(user): use the same one as the one used in the classic minimization // because they are already initialized. Moreover they contains more - // information. + // information? SparseBitset is_marked_; SparseBitset tmp_bitset_; SparseBitset is_simplified_; + std::vector tmp_to_keep_; // Used by AppendImplicationChains() to avoid processing a unit clause several // times. diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 9d57435469..a448c0a4d6 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -858,9 +858,6 @@ std::vector GetFullWorkerParameters( ModelHasSchedulingConstraints(cp_model); // Our current set of strategies - // - // TODO(user): Avoid launching two strategies if they are the same, - // like if there is no lp, or everything is already linearized at level 1. std::vector names; // Starts by adding user specified ones. @@ -1021,7 +1018,15 @@ std::vector GetFullWorkerParameters( if (result.size() > num_to_keep) { result.resize(std::max(0, num_to_keep)); + } else if (!result.empty() && num_to_keep >= 0) { + // If we have less parameters, duplicate the first one until we have enough. + // This is a bit hacky but easily allow to do experiment with n times the + // same subsolver. + while (result.size() < num_to_keep) { + result.push_back(result[0]); + } } + return result; } diff --git a/ortools/sat/gate_utils.h b/ortools/sat/gate_utils.h new file mode 100644 index 0000000000..18df423852 --- /dev/null +++ b/ortools/sat/gate_utils.h @@ -0,0 +1,199 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Functions to manipulate a "small" truth table where +// f(X0, X1, X2) is true iff bitmask[X0 + (X1 << 1) + (X2 << 2)] is true. +#ifndef ORTOOLS_SAT_GATE_UTILS_H_ +#define ORTOOLS_SAT_GATE_UTILS_H_ + +#include + +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "ortools/sat/sat_base.h" + +namespace operations_research::sat { + +using SmallBitset = uint32_t; + +// Sort the key and modify the truth table accordingly. +// +// Note that we don't deal with identical key here, but the function +// CanonicalizeFunctionTruthTable() does, and that is sufficient for our use +// case. +template +void CanonicalizeTruthTable(absl::Span key, + SmallBitset& bitmask) { + const int num_bits = key.size(); + CHECK_EQ(bitmask >> (1 << num_bits), 0); + for (int i = 0; i < num_bits; ++i) { + for (int j = i + 1; j < num_bits; ++j) { + if (key[i] <= key[j]) continue; + + std::swap(key[i], key[j]); + + // We need to swap bit positions i and j in bitmask. + SmallBitset new_bitmask = 0; + for (int p = 0; p < (1 << num_bits); ++p) { + const int value_i = (p >> i) & 1; + const int value_j = (p >> j) & 1; + int new_p = p; + new_p ^= (value_i << i) ^ (value_j << j); // Clear. + new_p ^= (value_i << j) ^ (value_j << i); // Swap. + new_bitmask |= ((bitmask >> p) & 1) << new_p; + } + bitmask = new_bitmask; + CHECK_EQ(bitmask >> (1 << num_bits), 0) + << i << " " << j << " " << num_bits; + } + } + CHECK(std::is_sorted(key.begin(), key.end())); +} + +// Given a clause, return the truth table corresponding to it. +// Namely, a single value should be excluded. +inline void FillKeyAndBitmask(absl::Span clause, + absl::Span key, + SmallBitset& bitmask) { + CHECK_EQ(clause.size(), key.size()); + const int num_bits = clause.size(); + bitmask = ~SmallBitset(0) >> (32 - (1 << num_bits)); // All allowed; + CHECK_EQ(bitmask >> (1 << num_bits), 0) << num_bits; + SmallBitset bit_to_remove = 0; + for (int i = 0; i < num_bits; ++i) { + key[i] = clause[i].Variable(); + bit_to_remove |= (clause[i].IsPositive() ? 0 : 1) << i; + } + bitmask ^= (SmallBitset(1) << bit_to_remove); + CHECK_EQ(bitmask >> (1 << num_bits), 0) << bit_to_remove << " " << num_bits; + CanonicalizeTruthTable(key, bitmask); +} + +// Returns true iff the truth table encoded in bitmask encode a function +// Xi = f(Xj, j != i); +template +bool IsFunction(int i, SmallBitset truth_table) { + DCHECK_GE(i, 0); + DCHECK_LT(i, num_bits); + + // We need to check that there is never two possibilities for Xi. + for (int p = 0; p < (1 << num_bits); ++p) { + if ((truth_table >> p) & (truth_table >> (p ^ (1 << i))) & 1) return false; + } + + return true; +} + +inline int AddHoleAtPosition(int i, int bitset) { + return (bitset & ((1 << i) - 1)) + ((bitset >> i) << (i + 1)); +} + +// The function is target = function_values[inputs as bit position]. +// +// TODO(user): This can be optimized with more bit twiddling if needed. +inline int CanonicalizeFunctionTruthTable(LiteralIndex& target, + absl::Span inputs, + int& int_function_values) { + // We want to fit on an int. + CHECK_LE(inputs.size(), 4); + + // We assume smaller type. + SmallBitset function_values = int_function_values; + + const int num_bits = inputs.size(); + CHECK_LE(num_bits, 4); // Truth table must fit on an int. + const SmallBitset all_one = (1 << (1 << num_bits)) - 1; + CHECK_EQ(function_values & ~all_one, 0); + + // Make sure target is positive. + if (!Literal(target).IsPositive()) { + target = Literal(target).Negated(); + function_values = function_values ^ all_one; + CHECK_EQ(function_values >> (1 << num_bits), 0); + } + + // Make sure all inputs are positive. + for (int i = 0; i < num_bits; ++i) { + if (Literal(inputs[i]).IsPositive()) continue; + + inputs[i] = Literal(inputs[i]).NegatedIndex(); + + // Position p go to position (p ^ (1 << i)). + SmallBitset new_truth_table = 0; + const SmallBitset to_xor = 1 << i; + for (int p = 0; p < (1 << num_bits); ++p) { + new_truth_table |= ((function_values >> p) & 1) << (p ^ to_xor); + } + function_values = new_truth_table; + CHECK_EQ(function_values >> (1 << num_bits), 0); + } + + // Sort the inputs now. + CanonicalizeTruthTable(inputs, function_values); + CHECK_EQ(function_values >> (1 << num_bits), 0); + + // Merge identical variables. + for (int i = 0; i < inputs.size(); ++i) { + for (int j = i + 1; j < inputs.size();) { + if (inputs[i] == inputs[j]) { + // Lets remove input j. + for (int k = j; k + 1 < inputs.size(); ++k) inputs[k] = inputs[k + 1]; + inputs.remove_suffix(1); + + SmallBitset new_truth_table = 0; + for (int p = 0; p < (1 << inputs.size()); ++p) { + int extended_p = AddHoleAtPosition(j, p); + extended_p |= ((p >> i) & 1) << j; // fill it with bit i. + new_truth_table |= ((function_values >> extended_p) & 1) << p; + } + function_values = new_truth_table; + } else { + ++j; + } + } + } + + // Lower arity? + // This can happen if the output do not depend on one of the inputs. + for (int i = 0; i < inputs.size();) { + bool remove = true; + for (int p = 0; p < (1 << inputs.size()); ++p) { + if (((function_values >> p) & 1) != + ((function_values >> (p ^ (1 << i))) & 1)) { + remove = false; + break; + } + } + if (remove) { + // Lets remove input i. + for (int k = i; k + 1 < inputs.size(); ++k) inputs[k] = inputs[k + 1]; + inputs.remove_suffix(1); + + SmallBitset new_truth_table = 0; + for (int p = 0; p < (1 << inputs.size()); ++p) { + const int extended_p = AddHoleAtPosition(i, p); + new_truth_table |= ((function_values >> extended_p) & 1) << p; + } + function_values = new_truth_table; + } else { + ++i; + } + } + + int_function_values = function_values; + return inputs.size(); +} + +} // namespace operations_research::sat + +#endif // ORTOOLS_SAT_GATE_UTILS_H_ diff --git a/ortools/sat/gate_utils_test.cc b/ortools/sat/gate_utils_test.cc new file mode 100644 index 0000000000..f4833ebaa3 --- /dev/null +++ b/ortools/sat/gate_utils_test.cc @@ -0,0 +1,147 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/gate_utils.h" + +#include +#include +#include +#include + +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/base/logging.h" +#include "ortools/sat/sat_base.h" + +namespace operations_research::sat { +namespace { + +TEST(CanonicalizeTruthTableTest, BasicBehavior1) { + std::array key = {0, 2, 1}; + + // no change here. + SmallBitset bitmask = 0b10101010; + CanonicalizeTruthTable(absl::MakeSpan(key), bitmask); + EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b10101010)); +} + +TEST(CanonicalizeTruthTableTest, BasicBehavior2) { + std::array key = {2, 0, 1}; + SmallBitset bitmask = 0b10101010; + CanonicalizeTruthTable(absl::MakeSpan(key), bitmask); + EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11110000)); +} + +TEST(CanonicalizeTruthTableTest, BasicBehavior3) { + std::array key = {1, 0, 2}; + SmallBitset bitmask = 0b10101010; + CanonicalizeTruthTable(absl::MakeSpan(key), bitmask); + EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11001100)); +} + +TEST(FillKeyAndBitmaskTest, BasicBehavior1) { + std::array key; + SmallBitset bitmask; + FillKeyAndBitmask({Literal(+1), Literal(-2), Literal(+3)}, + absl::MakeSpan(key), bitmask); + EXPECT_THAT(key, + ::testing::ElementsAre(BooleanVariable(0), BooleanVariable(1), + BooleanVariable(2))); + // The bit number 2 = 0b010 should be off. + EXPECT_EQ(std::bitset<8>(bitmask), std::bitset<8>(0b11111011)); +} + +TEST(IsFunctionTest, ConstantValue) { + EXPECT_TRUE(IsFunction<3>(0, 0b10101010)); + EXPECT_FALSE(IsFunction<3>(1, 0b10101010)); + EXPECT_FALSE(IsFunction<3>(2, 0b10101010)); +} + +TEST(AddHoleAtPositionTest, BasicTest) { + EXPECT_EQ(AddHoleAtPosition(0, 0xFF), 0b111111110); + EXPECT_EQ(AddHoleAtPosition(1, 0xFF), 0b111111101); + EXPECT_EQ(AddHoleAtPosition(8, 0xFF), 0b011111111); +} + +TEST(CanonicalizeFunctionTruthTableTest, RandomTest) { + absl::BitGen random; + const int num_vars = 8; + + for (int num_test = 0; num_test < 1000; ++num_test) { + // Lets generate a random function on k random variables. + const int k = absl::Uniform(random, 0, 4); + const int table = absl::Uniform(random, 0, 1 << (1 << k)); + const Literal output(BooleanVariable(100), absl::Bernoulli(random, 0.5)); + std::vector inputs; + for (int i = 0; i < k; ++i) { + inputs.push_back( + Literal(BooleanVariable(absl::Uniform(random, 0, num_vars)), + absl::Bernoulli(random, 0.5))); + } + + LiteralIndex new_output_index = output.Index(); + std::vector new_inputs = inputs; + std::vector new_inputs_index; + for (const Literal lit : new_inputs) { + new_inputs_index.push_back(lit.Index()); + } + int new_table = table; + const int new_size = CanonicalizeFunctionTruthTable( + new_output_index, absl::MakeSpan(new_inputs_index), new_table); + new_inputs_index.resize(new_size); + new_inputs.clear(); + for (const LiteralIndex lit_index : new_inputs_index) { + new_inputs.push_back(Literal(lit_index)); + } + + // Log before potential failure. + LOG(INFO) << "IN arity=" << k << " " << output << " = f(" << inputs << ") " + << std::bitset<16>(table); + LOG(INFO) << "OUT arity=" << new_size << " " << Literal(new_output_index) + << " = f(" << new_inputs << ") " << std::bitset<16>(new_table); + + // Now check that both function always take the same value. + for (int m = 0; m < (1 << num_vars); ++m) { + int index = 0; + for (int i = 0; i < inputs.size(); ++i) { + const Literal lit = inputs[i]; + int value = (m >> lit.Variable().value()) & 1; + if (!lit.IsPositive()) value = 1 - value; + index |= value << i; + } + const int target_value = (table >> index) & 1; + + int new_index = 0; + for (int i = 0; i < new_inputs.size(); ++i) { + const Literal lit = new_inputs[i]; + int value = (m >> lit.Variable().value()) & 1; + if (!lit.IsPositive()) value = 1 - value; + new_index |= value << i; + } + const int new_target_value = (new_table >> new_index) & 1; + + if (output == Literal(new_output_index)) { + ASSERT_EQ(target_value, new_target_value) << index << " " << new_index; + } else { + ASSERT_EQ(output, Literal(new_output_index).Negated()); + ASSERT_EQ(target_value, 1 - new_target_value) + << index << " " << new_index; + } + } + } +} + +} // namespace +} // namespace operations_research::sat diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 087a4e72f8..96fdebf1e5 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -34,7 +35,9 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/base/strong_vector.h" +#include "ortools/sat/clause.h" #include "ortools/sat/integer_base.h" +#include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -105,6 +108,8 @@ class IntegerEncoder { explicit IntegerEncoder(Model* model) : sat_solver_(model->GetOrCreate()), trail_(model->GetOrCreate()), + clause_id_generator_(model->GetOrCreate()), + lrat_proof_handler_(model->Mutable()), delayed_to_fix_(model->GetOrCreate()), domains_(*model->GetOrCreate()), num_created_variables_(0) {} @@ -291,9 +296,15 @@ class IntegerEncoder { Literal(sat_solver_->NewBooleanVariable(), true); literal_index_true_ = literal_true.Index(); - // This might return false if we are already UNSAT. - // TODO(user): Make sure we abort right away on unsat! - (void)sat_solver_->AddUnitClause(literal_true); + ClauseId clause_id = kNoClauseId; + if (lrat_proof_handler_ != nullptr) { + clause_id = clause_id_generator_->GetNextId(); + // We cannot prove `literal_true` by unit propagation, but we can with a + // RAT inference (trivial here since there are no clauses containing the + // negation of the pivot `literal_true`). + lrat_proof_handler_->AddInferredClause(clause_id, {literal_true}, {}); + } + trail_->EnqueueWithUnitReason(clause_id, literal_true); } return Literal(literal_index_true_); } @@ -328,6 +339,8 @@ class IntegerEncoder { SatSolver* sat_solver_; Trail* trail_; + ClauseIdGenerator* clause_id_generator_; + LratProofHandler* lrat_proof_handler_; DelayedRootLevelDeduction* delayed_to_fix_; const IntegerDomains& domains_; @@ -339,7 +352,7 @@ class IntegerEncoder { // corresponding to the same variable). // // Note that we only keep this for positive variable. - // The one for the negation can be infered by it. + // The one for the negation can be inferred by it. // // Like x >= 1 x >= 4 x >= 5 // Correspond to x <= 0 x <= 3 x <= 4 diff --git a/ortools/sat/integer_resolution.cc b/ortools/sat/integer_resolution.cc index 36ed6b2cec..de7f5a1db8 100644 --- a/ortools/sat/integer_resolution.cc +++ b/ortools/sat/integer_resolution.cc @@ -512,8 +512,8 @@ void IntegerConflictResolution::ComputeFirstUIPConflict( } if (uip_found) { - if (params_.binary_minimization_algorithm() == - SatParameters::BINARY_MINIMIZATION_FIRST) { + if (params_.binary_minimization_algorithm() != + SatParameters::NO_BINARY_MINIMIZATION) { if (conflict->empty()) { // This one will always stay in the conflict, even after // minimization. So we can use it to minimize the conflict and avoid diff --git a/ortools/sat/lrat_proof_handler.cc b/ortools/sat/lrat_proof_handler.cc index b772e44274..f5ede80c9b 100644 --- a/ortools/sat/lrat_proof_handler.cc +++ b/ortools/sat/lrat_proof_handler.cc @@ -14,6 +14,7 @@ #include "ortools/sat/lrat_proof_handler.h" #include +#include #include #include #include @@ -23,6 +24,7 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "absl/flags/flag.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -40,6 +42,7 @@ #include "ortools/sat/recordio.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" #if defined(_MSC_VER) ABSL_FLAG(std::string, cp_model_drat_output, ".\\drat.txt", @@ -423,6 +426,7 @@ std::unique_ptr LratProofHandler::MaybeCreate(Model* model) { LratProofHandler::LratProofHandler(Model* model) : id_(model->GetOrCreate()->NewSubSolverId()), + id_generator_(model->GetOrCreate()), proof_status_(model->GetOrCreate()) { const SatParameters& params = *model->GetOrCreate(); if (params.check_lrat_proof()) { @@ -448,7 +452,7 @@ LratProofHandler::LratProofHandler(Model* model) bool LratProofHandler::AddProblemClause(ClauseId id, absl::Span clause) { - VLOG(1) << "AddProblemClause: id=" << id + VLOG(2) << "AddProblemClause: id=" << id << " literals=" << absl::StrJoin(clause, ","); if (all_problem_clauses_loaded_ && debug_crash_on_error_) { LOG(FATAL) << "LRAT error: problem clauses must not be added after " @@ -481,7 +485,7 @@ bool LratProofHandler::AddInferredClause( ClauseId id, absl::Span clause, absl::Span unit_ids, absl::Span rat) { - VLOG(1) << "AddInferredClause: id=" << id + VLOG(2) << "AddInferredClause: id=" << id << " literals=" << absl::StrJoin(clause, ",") << " unit_ids=" << absl::StrJoin(unit_ids, ",") << " rat={" << absl::StrJoin(rat, " ") << "}"; @@ -508,7 +512,7 @@ bool LratProofHandler::AddInferredClause( bool LratProofHandler::AddImportedClause(ClauseId id, absl::Span clause) { - VLOG(1) << "AddImportedClause: id=" << id + VLOG(2) << "AddImportedClause: id=" << id << " literals=" << absl::StrJoin(clause, ","); if (lrat_checker_ != nullptr && !lrat_checker_->AddProblemClause(id, clause)) { @@ -526,7 +530,7 @@ bool LratProofHandler::AddImportedClause(ClauseId id, bool LratProofHandler::AddAssumedClause(ClauseId id, absl::Span clause) { - VLOG(1) << "AddAssumedClause: id=" << id + VLOG(2) << "AddAssumedClause: id=" << id << " literals=" << absl::StrJoin(clause, ","); if (debug_crash_on_error_) { LOG(FATAL) << "LRAT error: assumed clauses are not supposed to happen"; @@ -571,7 +575,7 @@ void LratProofHandler::DeleteClause(ClauseId id, delete_pinned_clause_ = true; return; } - VLOG(1) << "DeleteClause: id=" << id + VLOG(2) << "DeleteClause: id=" << id << " literals=" << absl::StrJoin(clause, ","); if (lrat_checker_ != nullptr) { lrat_checker_->DeleteClauses({id}); @@ -638,5 +642,177 @@ void LratProofHandler::Close(bool model_is_unsat) { } } +bool LratProofHandler::AddAndProveInferredClauseByEnumeration( + ClauseId new_id, absl::Span new_clause, + absl::Span ids_for_proof, + const CompactVectorVector& clauses_for_proof) { + CHECK_EQ(ids_for_proof.size(), clauses_for_proof.size()); + CHECK(!clauses_for_proof.empty()); + + // First we count the number of variables appearing and have a separate dense + // index for them. The first new_clause.size() dense index are exactly the + // literal of the new_clause. + absl::flat_hash_map to_dense_index; + std::vector dense_index_to_literal; + dense_index_to_literal.assign(new_clause.begin(), new_clause.end()); + for (const Literal lit : new_clause) { + const auto [it, inserted] = + to_dense_index.insert({lit.Variable(), to_dense_index.size()}); + if (!inserted) { + VLOG(2) << "Duplicate variable in new_clause! " << new_clause; + return false; + } + } + + // Then any new BooleanVariable appearing get the next dense index. + std::vector relevant_literals; + for (int i = 0; i < clauses_for_proof.size(); ++i) { + for (const Literal lit : clauses_for_proof[i]) { + const auto [it, inserted] = + to_dense_index.insert({lit.Variable(), to_dense_index.size()}); + if (inserted) { + relevant_literals.push_back(lit); + } + } + } + + // Too many variables. + // + // TODO(user): The limit can be increased a bit if needed. + if (to_dense_index.size() > 6) { + VLOG(2) << "Too many variables"; + return false; + } + + // For the proof we will need all clauses of the form + // {new_clause, l0, ..., lk} for all k in [0, n) and + // li = relevant_literals[i] OR relevant_literals[i].Negated(). + // + // That give us 2^(n + 1) intermediate clauses. + // Their ids will be stored in (1 << k) + binary_encoding_of_the_li. + const int n = to_dense_index.size() - new_clause.size(); + CHECK_EQ(n, relevant_literals.size()); + const int num_intermediates = 1 << (n + 1); + std::vector ids(num_intermediates, kNoClauseId); + + if (n == 0) { + VLOG(2) << "Nothing to prove! An existing clause is included inside"; + return false; + } + + VLOG(2) << "Starting proof n= " << n << " " << num_intermediates; + + // Any initial clause can be used to prove all the intermediates that contains + // it. Note that this code supports duplicate literals in the clauses. + for (int i = 0; i < clauses_for_proof.size(); ++i) { + bool skip = false; + int base_index = 0; + int mask = 0; + int k = 0; + for (const Literal lit : clauses_for_proof[i]) { + const int dense_index = to_dense_index[lit.Variable()]; + if (dense_index < new_clause.size()) { + // Check that the literal is the same as in the new_clause, if + // not, this clause will not be needed for the proof. + if (lit != new_clause[dense_index]) { + skip = true; + break; + } + } else { + k = std::max(k, dense_index); + mask |= 1 << dense_index; + if (lit == relevant_literals[dense_index - new_clause.size()]) { + base_index |= 1 << dense_index; + } + } + } + if (skip) continue; + + mask >>= new_clause.size(); + base_index >>= new_clause.size(); + k = k + 1 - new_clause.size(); + + VLOG(2) << k << " " << std::bitset<8>(mask) << " " + << std::bitset<8>(base_index); + + // TODO(user): we could be faster here if it become needed. + for (int m = 0; m < (1 << n); ++m) { + if ((m & mask) != base_index) continue; // not included. + const int index = m | base_index; + for (int j = k; j <= n; ++j) { + if (index >> j == 0) { + VLOG(2) << "Included in " << j << " " + << std::bitset<8>((1 << j) | index); + ids[(1 << j) | index] = ids_for_proof[i]; + } + } + } + } + + // We can prove the others by decreasing k. + std::vector tmp_clause; + tmp_clause.assign(new_clause.begin(), new_clause.end()); + std::vector id_need_deletion(num_intermediates, false); + for (int k = n; --k >= 0;) { + for (int m = 0; m < (1 << k); ++m) { + const int index = (1 << k) | m; + if (ids[index] != kNoClauseId) continue; // Already proven. + + // Generate the tmp_clause. + tmp_clause.resize(new_clause.size()); + for (int i = 0; i < k; ++i) { + tmp_clause.push_back(relevant_literals[i]); + if (((index >> i) & 1) == 0) { + tmp_clause.back() = tmp_clause.back().Negated(); + } + } + + // Prove it from the two clauses at k + 1. + const int higher1 = index ^ ((0b11) << k); + const int higher2 = index ^ ((0b10) << k); + const ClauseId id1 = ids[higher1]; + const ClauseId id2 = ids[higher2]; + if (id1 == kNoClauseId || id2 == kNoClauseId) { + VLOG(2) << "missing higher level clauses in the resolution." + << " index: " << std::bitset<8>(index) + << " higher1: " << std::bitset<8>(higher1) + << " higher2: " << std::bitset<8>(higher2); + return false; + } + + ids[index] = k == 0 ? new_id : id_generator_->GetNextId(); + if (k != 0) { + VLOG(2) << "temporary !! " << ids[index] << " " << tmp_clause; + id_need_deletion[index] = true; // temporary. + } + if (!AddInferredClause(ids[index], tmp_clause, {id1, id2})) { + VLOG(2) << "Failed resolution step"; + return false; + } + + if (k == 0) { + DCHECK_EQ(new_clause, tmp_clause); + VLOG(2) << "Proven " << new_clause << "!"; + } + + // Lets delete the ids if they were temporary. + if (id_need_deletion[higher1]) { + tmp_clause.push_back(relevant_literals[k].Negated()); + VLOG(2) << "deleting: " << id1 << " " << tmp_clause; + DeleteClause(id1, tmp_clause); + tmp_clause.pop_back(); + } + if (id_need_deletion[higher2]) { + tmp_clause.push_back(relevant_literals[k]); + VLOG(2) << "deleting: " << id2 << " " << tmp_clause; + DeleteClause(id2, tmp_clause); + tmp_clause.pop_back(); + } + } + } + + return true; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/lrat_proof_handler.h b/ortools/sat/lrat_proof_handler.h index 3de8ec68e3..8e4bea9106 100644 --- a/ortools/sat/lrat_proof_handler.h +++ b/ortools/sat/lrat_proof_handler.h @@ -33,6 +33,7 @@ #include "ortools/sat/recordio.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" namespace operations_research { namespace sat { @@ -149,6 +150,19 @@ class LratProofHandler { absl::Span unit_ids, absl::Span rat = {}); + // This assumes that the 'new_clause' to prove and all the ones needed for the + // proof only touch a small number of variables (<= 6). It will then prove the + // new clause by enumerating all possibilities and producing the relevant + // intermediate LRAT RUP steps. + // + // The last two arguments must have the same size and are in one to one + // correspondence. Note that we might not need all the given clauses in the + // proof. + bool AddAndProveInferredClauseByEnumeration( + ClauseId new_id, absl::Span new_clause, + absl::Span ids_for_proof, + const CompactVectorVector& clauses_for_proof); + // Adds a clause which was inferred by another worker. Returns true if // successful (the operation can fail if LRAT checks are enabled, and the ID // is already used by another clause). @@ -190,6 +204,7 @@ class LratProofHandler { bool LratError() const; const int id_; + ClauseIdGenerator* id_generator_; SharedLratProofStatus* proof_status_; std::unique_ptr lrat_checker_; std::unique_ptr lrat_writer_; diff --git a/ortools/sat/lrat_proof_handler_test.cc b/ortools/sat/lrat_proof_handler_test.cc new file mode 100644 index 0000000000..72dbad3852 --- /dev/null +++ b/ortools/sat/lrat_proof_handler_test.cc @@ -0,0 +1,76 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/lrat_proof_handler.h" + +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/util.h" + +namespace operations_research::sat { +namespace { + +TEST(AddAndProveInferredClauseByEnumerationTest, XorEquivalence) { + // We assume a = XOR(c,d) and b = XOR(c,d) and we want to prove a => b + const Literal a(+1); + const Literal b(+2); + const Literal c(+3); + const Literal d(+4); + + // Encode the two XOR gates. + CompactVectorVector clauses; + for (const Literal x : {a, b}) { + clauses.Add({c, d, x.Negated()}); + clauses.Add({c.Negated(), d, x}); + clauses.Add({c, d.Negated(), x}); + clauses.Add({c.Negated(), d.Negated(), x.Negated()}); + } + + // Create the lrat checker. + Model model; + auto* params = model.GetOrCreate(); + params->set_check_lrat_proof(true); + params->set_check_drat_proof(true); + std::unique_ptr lrat = + LratProofHandler::MaybeCreate(&model); + + // Lets create ids for all these clauses. + auto* id_generator = model.GetOrCreate(); + std::vector clause_ids; + for (int i = 0; i < clauses.size(); ++i) { + const ClauseId id = id_generator->GetNextId(); + clause_ids.push_back(id); + lrat->AddProblemClause(id, clauses[i]); + } + lrat->EndProblemClauses(); + + // This should be enough to prove equivalence. + { + std::vector to_prove = {b.Negated(), a}; + EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration( + id_generator->GetNextId(), to_prove, clause_ids, clauses)); + } + { + std::vector to_prove = {a.Negated(), b}; + EXPECT_TRUE(lrat->AddAndProveInferredClauseByEnumeration( + id_generator->GetNextId(), to_prove, clause_ids, clauses)); + } +} + +} // namespace +} // namespace operations_research::sat diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 1ea2bed4d3..051be3dd50 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -242,9 +242,17 @@ class AssignmentView { // The ID of a clause. DEFINE_STRONG_INT_TYPE(ClauseId, int64_t); - constexpr ClauseId kNoClauseId(0); +// A generator for ClauseIds. Not thread-safe. +class ClauseIdGenerator { + public: + ClauseId GetNextId() { return ClauseId(next_id_++); } + + private: + ClauseId next_id_ = ClauseId(1); +}; + // Forward declaration. class SatClause; class SatPropagator; @@ -346,12 +354,12 @@ class Trail { EnqueueHelper( Literal* trail_ptr, AssignmentInfo* current_info, AssignmentInfo* info_ptr, VariablesAssignment* assignment, - util_intops::StrongVector* unit_clause_id) + util_intops::StrongVector* clause_ids) : trail_ptr_(trail_ptr), current_info_(current_info), info_ptr_(info_ptr), bitset_(assignment->GetBitsetView()), - unit_clause_id_(unit_clause_id) {} + clause_ids_(clause_ids) {} void EnqueueAtLevel(Literal true_literal, int level) { bitset_.Set(true_literal); @@ -364,10 +372,10 @@ class Trail { void EnqueueWithUnitReason(Literal true_literal, ClauseId clause_id) { DCHECK_NE(clause_id, kNoClauseId); const BooleanVariable var = true_literal.Variable(); - if (var.value() >= unit_clause_id_->size()) { - unit_clause_id_->resize(var.value() + 1, kNoClauseId); + if (var.value() >= clause_ids_->size()) { + clause_ids_->resize(var.value() + 1, kNoClauseId); } - (*unit_clause_id_)[var] = clause_id; + (*clause_ids_)[var] = clause_id; bitset_.Set(true_literal); AssignmentInfo* info = info_ptr_ + true_literal.Variable().value(); @@ -389,12 +397,12 @@ class Trail { AssignmentInfo* current_info_; AssignmentInfo* info_ptr_; Bitset64::View bitset_; - util_intops::StrongVector* unit_clause_id_; + util_intops::StrongVector* clause_ids_; }; EnqueueHelper GetEnqueueHelper(int propagator_id) { current_info_.type = propagator_id; return EnqueueHelper(trail_.data(), ¤t_info_, info_.data(), - &assignment_, &unit_clause_id_); + &assignment_, &clause_ids_); } // Specific Enqueue() for search decisions. @@ -435,13 +443,7 @@ class Trail { // Specific Enqueue() version for unit clauses. void EnqueueWithUnitReason(ClauseId clause_id, Literal true_literal) { - if (clause_id != kNoClauseId) { - const BooleanVariable var = true_literal.Variable(); - if (var.value() >= unit_clause_id_.size()) { - unit_clause_id_.resize(var.value() + 1, kNoClauseId); - } - unit_clause_id_[var] = clause_id; - } + MaybeSetClauseId(true_literal, clause_id); EnqueueAtLevel(true_literal, AssignmentType::kUnitReason, 0); } @@ -459,18 +461,21 @@ class Trail { // Enqueues the given literal using the current content of // GetEmptyVectorToStoreReason() as the reason. This API is a bit more - // leanient and does not require the literal to be unassigned. If it is + // lenient and does not require the literal to be unassigned. If it is // already assigned to false, then MutableConflict() will be set appropriately // and this will return false otherwise this will enqueue the literal and // returns true. - ABSL_MUST_USE_RESULT bool EnqueueWithStoredReason(Literal true_literal) { + ABSL_MUST_USE_RESULT bool EnqueueWithStoredReason(ClauseId clause_id, + Literal true_literal) { if (assignment_.LiteralIsTrue(true_literal)) return true; if (assignment_.LiteralIsFalse(true_literal)) { *MutableConflict() = reasons_repository_[Index()]; MutableConflict()->push_back(true_literal); + failing_clause_id_ = clause_id; return false; } + MaybeSetClauseId(true_literal, clause_id); Enqueue(true_literal, AssignmentType::kCachedReason); const BooleanVariable var = true_literal.Variable(); reasons_[var] = reasons_repository_[info_[var].trail_index]; @@ -516,8 +521,17 @@ class Trail { ClauseId GetUnitClauseId(BooleanVariable var) const { DCHECK(AssignmentType(var) == AssignmentType::kUnitReason); DCHECK_EQ(Info(var).level, 0); - if (var.value() >= unit_clause_id_.size()) return kNoClauseId; - return unit_clause_id_[var]; + if (var.value() >= clause_ids_.size()) return kNoClauseId; + return clause_ids_[var]; + } + + // Returns the ID of the clause which is the reason why the given variable was + // enqueued, or kNoClauseId if there is none. The variable must have been + // enqueued with EnqueueWithStoredReason(). + ClauseId GetStoredReasonClauseId(BooleanVariable var) const { + DCHECK(AssignmentType(var) == AssignmentType::kCachedReason); + if (var.value() >= clause_ids_.size()) return kNoClauseId; + return clause_ids_[var]; } // If a variable was propagated with EnqueueWithSameReasonAs(), returns its @@ -583,6 +597,7 @@ class Trail { std::vector* MutableConflict() { ++conflict_timestamp_; failing_sat_clause_ = nullptr; + failing_clause_id_ = kNoClauseId; return &conflict_; } @@ -600,9 +615,16 @@ class Trail { // Specific SatClause interface so we can update the conflict clause activity. // Note that MutableConflict() automatically sets this to nullptr, so we can // know whether or not the last conflict was caused by a clause. - void SetFailingSatClause(SatClause* clause) { failing_sat_clause_ = clause; } + void SetFailingSatClause(SatClause* clause) { + failing_sat_clause_ = clause; + failing_clause_id_ = kNoClauseId; + } SatClause* FailingSatClause() const { return failing_sat_clause_; } + // Returns the LRAT ID of the failing clause. This ID is only set if a + // conflict is detected in EnqueueWithStoredReason(). + ClauseId FailingClauseId() const { return failing_clause_id_; } + // Getters. int NumVariables() const { return trail_.size(); } int64_t NumberOfEnqueues() const { return num_untrailed_enqueues_ + Index(); } @@ -664,6 +686,16 @@ class Trail { private: ConflictResolutionFunction resolution_; + void MaybeSetClauseId(Literal true_literal, ClauseId clause_id) { + if (clause_id != kNoClauseId) { + const BooleanVariable var = true_literal.Variable(); + if (var.value() >= clause_ids_.size()) { + clause_ids_.resize(var.value() + 1, kNoClauseId); + } + clause_ids_[var] = clause_id; + } + } + // Finds all literals between the current trail index and the given one // assigned at the current level or lower, and re-enqueues them with the same // reason. @@ -678,9 +710,11 @@ class Trail { int64_t conflict_timestamp_ = 0; std::vector conflict_; util_intops::StrongVector info_; - // The ID of unit clauses (literals enqueued at level 0 with a kUnitReason). - util_intops::StrongVector unit_clause_id_; + // The ID of unit clauses (literals enqueued at level 0 with a kUnitReason) + // and of reason clauses for literals enqueued with a stored reason. + util_intops::StrongVector clause_ids_; SatClause* failing_sat_clause_; + ClauseId failing_clause_id_; // Data used by EnqueueWithSameReasonAs(). util_intops::StrongVector diff --git a/ortools/sat/sat_inprocessing.cc b/ortools/sat/sat_inprocessing.cc index c41665b8fb..dae1ec3d5d 100644 --- a/ortools/sat/sat_inprocessing.cc +++ b/ortools/sat/sat_inprocessing.cc @@ -14,6 +14,7 @@ #include "ortools/sat/sat_inprocessing.h" #include +#include #include #include #include @@ -37,6 +38,7 @@ #include "ortools/base/timer.h" #include "ortools/graph/connected_components.h" #include "ortools/sat/clause.h" +#include "ortools/sat/gate_utils.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/probing.h" @@ -1916,14 +1918,46 @@ GateCongruenceClosure::~GateCongruenceClosure() { }); } +template +void GateCongruenceClosure::AddToTruthTable( + absl::Span clause, + absl::flat_hash_map, SmallBitset>& + data) { + CHECK_EQ(clause.size(), arity); + std::array key; + SmallBitset bitmask; + FillKeyAndBitmask(clause, absl::MakeSpan(key), bitmask); + for (const BooleanVariable var : key) { + CHECK(!implication_graph_->IsRemoved(Literal(var, true))); + } + auto [it, inserted] = data.insert({key, bitmask}); + if (!inserted) { + const SmallBitset old = it->second; + it->second &= bitmask; // Remove one value. + if (old != it->second) { + // TODO(user): keep id for proof. + } + } +} + // Note that this is the "hot" part of the algo, once we have the and gates, // the congruence closure should be quite fast. -void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) { +void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( + PresolveTimer& timer) { + truth_table3_.clear(); + truth_table4_.clear(); + std::vector candidates; for (SatClause* clause : clause_manager_->AllClausesInCreationOrder()) { if (timer.WorkLimitIsReached()) break; if (clause->size() == 0) continue; + if (clause->size() == 3) { + AddToTruthTable<3>(clause->AsSpan(), truth_table3_); + } else if (clause->size() == 4) { + AddToTruthTable<4>(clause->AsSpan(), truth_table4_); + } + // Used for an optimization below. int min_num_implications = std::numeric_limits::max(); Literal lit_with_less_implications; @@ -2030,6 +2064,8 @@ void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) { // Add the detected gate (its inputs are the negation of each clause // literal other than the target). gates_target_.push_back(target.Index()); + gates_type_.push_back(kAndGateType); + const int index = gates_inputs_.Add({}); for (const Literal l : clause->AsSpan()) { if (l == target) continue; @@ -2048,6 +2084,60 @@ void GateCongruenceClosure::ExtractAndGates(PresolveTimer& timer) { // a single base clause of size n will correspond to n and_gates ! } } + + timer.AddCounter("and_gates", gates_inputs_.size()); +} + +template +void GateCongruenceClosure::ExtractShortGates( + PresolveTimer& timer, + const absl::flat_hash_map, SmallBitset>& + data) { + // For a table on n variables, we look for function x = f(n - 1) variable. + const int num_bits = arity - 1; + + // TODO(user): This is non-deterministic order. We need to fix that or + // initially sort the queue of gates to process. + int num_functions = 0; + for (const auto [key, truth_table] : data) { + for (int i = 0; i < arity; ++i) { + if (!IsFunction(i, truth_table)) continue; + ++num_functions; + + gates_target_.push_back(Literal(key[i], true)); + gates_inputs_.Add({}); + for (int j = 0; j < arity; ++j) { + if (i != j) { + gates_inputs_.AppendToLastVector(Literal(key[j], true)); + } + } + + // Generate the function truth table as a type. + // We will canonicalize it further in the main loop. + unsigned int type = 0; + for (int p = 0; p < (1 << num_bits); ++p) { + // Expand from (arity - 1) bits to (arity) bits. + const int bigger_p = AddHoleAtPosition(i, p); + + if ((truth_table >> (bigger_p + (1 << i))) & 1) { + // target is 1 at this position. + type |= 1 << p; + DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function. + } else { + // Note that if there is no feasible assignment for a given p, first + // we could have learned a smaller clause, but also we don't really + // care what is the value of the target at that point p, so we use + // zero. + } + } + + gates_type_.push_back(type); + gates_clause_.push_back(nullptr); + } + } + + timer.AddCounter(absl::StrCat("table", arity), data.size()); + timer.AddCounter(absl::StrCat("fn", num_bits), num_functions); } namespace { @@ -2290,17 +2380,27 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { gates_target_.clear(); gates_inputs_.clear(); + gates_type_.clear(); gates_clause_.clear(); - // TODO(user): Extract more gates, and encode store their type. - ExtractAndGates(timer); - timer.AddCounter("and_gates", gates_inputs_.size()); + ExtractAndGatesAndFillShortTruthTables(timer); + + // TODO(user): We currently do not support this with lrat. Fix. + if (lrat_proof_handler_ == nullptr) { + ExtractShortGates<3>(timer, truth_table3_); + ExtractShortGates<4>(timer, truth_table4_); + } + + // All vector have the same size. + // Except gates_clause_ which is only filled if we need proof. + CHECK_EQ(gates_target_.size(), gates_type_.size()); + CHECK_EQ(gates_target_.size(), gates_inputs_.size()); // If two gates have the same type and the same inputs, their targets are // equivalent. We use an hash set to detect that the inputs are the same. absl::flat_hash_set gate_set( - /*capacity=*/gates_inputs_.size(), GateHash(&gates_inputs_), - GateEq(&gates_inputs_)); + /*capacity=*/gates_inputs_.size(), GateHash(&gates_type_, &gates_inputs_), + GateEq(&gates_type_, &gates_inputs_)); // Used to find representatives as we detect equivalent literal. DenseConnectedComponentsFinder union_find; @@ -2330,6 +2430,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { int num_units = 0; int num_processed = 0; int num_equivalences = 0; + int arity1_equivalences = 0; while (!queue.empty()) { ++num_processed; const int id = queue.back(); @@ -2368,13 +2469,13 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { if (is_equivalent) { CHECK_NE(id, other_id); CHECK_GE(other_id, 0); + CHECK_EQ(gates_type_[id], gates_type_[other_id]); CHECK_EQ(absl::Span(gates_inputs_[id]), absl::Span(gates_inputs_[other_id])); // We detected a <=> b (or, equivalently, rep(a) <=> rep(b)). const LiteralIndex a = gates_target_[id]; const LiteralIndex b = gates_target_[other_id]; - input_literals_to_gate.RemoveFromFutureOutput(id); if (lrat_proof_handler_ != nullptr) { lrat_helper.ShortenEquivalencesWithRepresentative(Literal(a)); @@ -2382,6 +2483,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { } const LiteralIndex rep_a(union_find.FindRoot(a.value())); const LiteralIndex rep_b(union_find.FindRoot(b.value())); + if (rep_a != rep_b) { ++num_equivalences; const Literal rep_lit_a(rep_a); @@ -2401,45 +2503,58 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { return false; } - union_find.AddEdge(rep_a.value(), rep_b.value()); - const LiteralIndex rep(union_find.FindRoot(rep_b.value())); - const LiteralIndex to_merge = rep == rep_a ? rep_b : rep_a; - input_literals_to_gate.MergeInto(to_merge, rep); - if (lrat_proof_handler_ != nullptr) { - if (rep == rep_a) { - lrat_helper.AddGateEquivalenceClauses( - rep_lit_b, rep_b_implies_rep_a, rep_a_implies_rep_b); - } else { - lrat_helper.AddGateEquivalenceClauses( - rep_lit_a, rep_a_implies_rep_b, rep_b_implies_rep_a); - } - } + for (const bool negate : {false, true}) { + const LiteralIndex x = + negate ? Literal(rep_a).NegatedIndex() : rep_a; + const LiteralIndex y = + negate ? Literal(rep_b).NegatedIndex() : rep_b; - // Re-add to the queue all gates with touched inputs. - // - // TODO(user): I think we could only add the gates of "to_merge" - // before we merge. This part of the code is quite quick in any case. - for (const int gate_id : input_literals_to_gate[rep]) { - if (in_queue[gate_id]) continue; - queue.push_back(gate_id); - in_queue[gate_id] = true; + union_find.AddEdge(x.value(), y.value()); + const LiteralIndex rep(union_find.FindRoot(y.value())); + const LiteralIndex to_merge = rep == x ? y : x; + input_literals_to_gate.MergeInto(to_merge, rep); + + if (lrat_proof_handler_ != nullptr) { + if (rep == x) { + lrat_helper.AddGateEquivalenceClauses( + Literal(y), + negate ? rep_a_implies_rep_b : rep_b_implies_rep_a, + negate ? rep_b_implies_rep_a : rep_a_implies_rep_b); + } else { + lrat_helper.AddGateEquivalenceClauses( + Literal(x), + negate ? rep_b_implies_rep_a : rep_a_implies_rep_b, + negate ? rep_a_implies_rep_b : rep_b_implies_rep_a); + } + } + + // Re-add to the queue all gates with touched inputs. + // + // TODO(user): I think we could only add the gates of "to_merge" + // before we merge. This part of the code is quite quick in any + // case. + for (const int gate_id : input_literals_to_gate[rep]) { + if (in_queue[gate_id]) continue; + queue.push_back(gate_id); + in_queue[gate_id] = true; + } } } break; } - // Canonicalize. + // Canonicalize on pass zero, otherwise loop. + // Note that even if nothing change, we want to run canonicalize at + // least once on the small "truth table" gates. // // Note that sorting works for and_gate and any gate that does not depend // on the order of its inputs. But if we add more fancy functions, we will // need to be careful. - // - // TODO(user): Because we fix literal, we should also deal with fixed - // literals here. Right now we defer this to a later step, but it might - // reduce the cascading effect of finding more equivalences. - if (pass == 0) { - marked_.ResetAllToFalse(); + if (pass > 0) continue; + + if (gates_type_[id] == kAndGateType) { absl::Span inputs = gates_inputs_[id]; + marked_.ResetAllToFalse(); int new_size = 0; bool is_unit = false; for (const LiteralIndex l : inputs) { @@ -2454,13 +2569,16 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { if (marked_[Literal(rep).Negated()]) { is_unit = true; const Literal to_fix = Literal(gates_target_[id]).Negated(); - absl::InlinedVector clause_ids; - if (lrat_proof_handler_ != nullptr) { - lrat_helper.AppendFixAndGateTargetClauses(id, Literal(rep), - clause_ids); - } - if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) { - return false; + if (!assignment_.LiteralIsTrue(to_fix)) { + absl::InlinedVector clause_ids; + if (lrat_proof_handler_ != nullptr) { + lrat_helper.AppendFixAndGateTargetClauses(id, Literal(rep), + clause_ids); + } + if (!clause_manager_->InprocessingFixLiteral(to_fix, + clause_ids)) { + return false; + } } break; } @@ -2480,6 +2598,50 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { CHECK_GT(new_size, 0); std::sort(inputs.begin(), inputs.begin() + new_size); gates_inputs_.Shrink(id, new_size); + + // Lets convert to the short "type" if we can. The truth table is simply + // a 1 on the last position (where all inputs are ones). We fall back to + // the case below to canonicalize further. + if (new_size > 4 || lrat_proof_handler_ != nullptr) continue; + gates_type_[id] = 1 << ((1 << new_size) - 1); + } + + // Generic "short" gates. + // We just take the representative and re-canonicalize. + absl::Span inputs = gates_inputs_[id]; + CHECK_GE(gates_type_[id], 0); + CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0); + for (LiteralIndex& lit_ref : inputs) { + lit_ref = LiteralIndex(union_find.FindRoot(lit_ref.value())); + } + + const int new_size = CanonicalizeFunctionTruthTable( + gates_target_[id], inputs, gates_type_[id]); + if (new_size < inputs.size()) { + gates_inputs_.Shrink(id, new_size); + } + + if (new_size == 1) { + // We have a function of size 1! this is an equivalence. + // + // TODO(user): deal with it. + ++arity1_equivalences; + input_literals_to_gate.RemoveFromFutureOutput(id); + break; + } else if (new_size == 0) { + // We have a fixed function ! just fix the literal. + CHECK(Literal(gates_target_[id]).IsPositive()); + const Literal to_fix{Literal(gates_target_[id]).Variable(), + (gates_type_[id] & 1) == 1}; + if (!assignment_.LiteralIsTrue(to_fix)) { + absl::InlinedVector clause_ids; + if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) { + return false; + } + ++num_units; + } + input_literals_to_gate.RemoveFromFutureOutput(id); + break; } } } @@ -2490,6 +2652,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { total_equivalences_ += num_equivalences; total_num_units_ += num_units; + timer.AddCounter("arity1_equivalences", arity1_equivalences); timer.AddCounter("units", num_units); timer.AddCounter("processed", num_processed); timer.AddCounter("equivalences", num_equivalences); diff --git a/ortools/sat/sat_inprocessing.h b/ortools/sat/sat_inprocessing.h index 74fab5d716..167960d8be 100644 --- a/ortools/sat/sat_inprocessing.h +++ b/ortools/sat/sat_inprocessing.h @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "ortools/base/strong_vector.h" #include "ortools/sat/clause.h" +#include "ortools/sat/gate_utils.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" @@ -439,7 +440,8 @@ class BoundedVariableElimination { class GateCongruenceClosure { public: explicit GateCongruenceClosure(Model* model) - : sat_solver_(model->GetOrCreate()), + : assignment_(model->GetOrCreate()->Assignment()), + sat_solver_(model->GetOrCreate()), implication_graph_(model->GetOrCreate()), clause_manager_(model->GetOrCreate()), clause_id_generator_(model->GetOrCreate()), @@ -454,23 +456,29 @@ class GateCongruenceClosure { private: struct GateHash { - explicit GateHash(CompactVectorVector* g) - : gates_inputs(g) {} + explicit GateHash(std::vector* t, + CompactVectorVector* g) + : gates_type(t), gates_inputs(g) {} std::size_t operator()(int gate_id) const { - return absl::HashOf((*gates_inputs)[gate_id]); + return absl::HashOf((*gates_type)[gate_id], (*gates_inputs)[gate_id]); } - CompactVectorVector* gates_inputs; + const std::vector* gates_type; + const CompactVectorVector* gates_inputs; }; struct GateEq { - explicit GateEq(CompactVectorVector* g) - : gates_inputs(g) {} + explicit GateEq(std::vector* t, + CompactVectorVector* g) + : gates_type(t), gates_inputs(g) {} bool operator()(int gate_a, int gate_b) const { if (gate_a == gate_b) return true; + // We use absl::span<> comparison. - return (*gates_inputs)[gate_a] == (*gates_inputs)[gate_b]; + return ((*gates_type)[gate_a] == (*gates_type)[gate_b]) && + ((*gates_inputs)[gate_a] == (*gates_inputs)[gate_b]); } - CompactVectorVector* gates_inputs; + const std::vector* gates_type; + const CompactVectorVector* gates_inputs; }; // Recovers "target_literal = and(literals)" from the model. @@ -483,8 +491,20 @@ class GateCongruenceClosure { // - for all i, target_literal => literal_i (direct binary implication) // - all literal at true => target_literal, this is a clause: // (not(literal[i]) for all i, target_literal). - void ExtractAndGates(PresolveTimer& timer); + void ExtractAndGatesAndFillShortTruthTables(PresolveTimer& timer); + // From possible assignment of "arity" given variables, extract functions. + template + void ExtractShortGates( + PresolveTimer& timer, + const absl::flat_hash_map, + SmallBitset>& data); + template + void AddToTruthTable(absl::Span clause, + absl::flat_hash_map, + SmallBitset>& data); + + const VariablesAssignment& assignment_; SatSolver* sat_solver_; BinaryImplicationGraph* implication_graph_; ClauseManager* clause_manager_; @@ -501,17 +521,33 @@ class GateCongruenceClosure { // A Boolean gates correspond to target = f(inputs). // // Note that the inputs are canonicalized. For and_gates, they are sorted, - // since the gate function does not depend on the order. + // since the gate function does not depend on the order. The type of an + // and_gates is kAndGateType. // - // TODO(user): for now we have a single gate type. We can easily support more - // by creating an extra std::vector and adding that to our - // GateHash/GateEq hash_set. + // Otherwise, we support generic 2 and 3 inputs gates where the type is the + // truth table. i.e. target = type[sum value_of_inputs[i] * 2^i]. For such + // gate, the target and inputs will always be canonicalized to positive and + // sorted literal. We just update the truth table accordingly. + static constexpr int kAndGateType = -1; std::vector gates_target_; + std::vector gates_type_; CompactVectorVector gates_inputs_; + // For each gate, "the" corresponding clause. For a gate a = and(x,y,...) this // is the clause "x and y and ... => a". Only used for LRAT. std::vector gates_clause_; + // Map (Xi) (sorted) to a bitmask corresponding to the allowed values. + // We loop over all short clauses to fill this. + // + // TODO(user): Shorter clauses impact larger truth table too and we can + // combine two size 3 to construct a size 4 (needed for ITE-gate). + // not ideal. + absl::flat_hash_map, SmallBitset> + truth_table3_; + absl::flat_hash_map, SmallBitset> + truth_table4_; + // For stats. double total_dtime_ = 0.0; double total_wtime_ = 0.0; diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 587d51a164..af07ac9147 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -127,10 +127,11 @@ message SatParameters { enum BinaryMinizationAlgorithm { reserved 2, 3, 4; NO_BINARY_MINIMIZATION = 0; - BINARY_MINIMIZATION_FIRST = 1; + BINARY_MINIMIZATION_FROM_UIP = 1; + BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS = 5; } optional BinaryMinizationAlgorithm binary_minimization_algorithm = 34 - [default = BINARY_MINIMIZATION_FIRST]; + [default = BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS]; // At a really low cost, during the 1-UIP conflict computation, it is easy to // detect if some of the involved reasons are subsumed by the current diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 78f4ac9e1b..7f65d7319e 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -974,10 +974,16 @@ void SatSolver::ProcessCurrentConflict( switch (parameters_->binary_minimization_algorithm()) { case SatParameters::NO_BINARY_MINIMIZATION: break; - case SatParameters::BINARY_MINIMIZATION_FIRST: + case SatParameters::BINARY_MINIMIZATION_FROM_UIP: binary_implication_graph_->MinimizeConflictFirst( *trail_, &learned_conflict_, &is_marked_, - clause_ids_for_minimization); + clause_ids_for_minimization, /*also_use_decisions=*/false); + break; + case SatParameters::BINARY_MINIMIZATION_FROM_UIP_AND_DECISIONS: + binary_implication_graph_->MinimizeConflictFirst( + *trail_, &learned_conflict_, &is_marked_, + clause_ids_for_minimization, /*also_use_decisions=*/true); + break; } DCHECK(IsConflictValid(learned_conflict_)); } @@ -1060,7 +1066,7 @@ void SatSolver::ProcessCurrentConflict( for (const Literal l : clause) { if (Assignment().LiteralIsFalse(l)) ++num_false; } - if (num_false == clause.size()) { + if (num_false == clause.size() || clause.size() == 1) { int max_level = 0; for (const Literal l : clause) { const int level = AssignmentLevel(l.Variable()); @@ -1366,6 +1372,8 @@ void SatSolver::AppendLratProofForFailingClause( const SatClause* failing_sat_clause = trail_->FailingSatClause(); if (failing_sat_clause != nullptr) { failing_clause_id = clauses_propagator_->GetClauseId(failing_sat_clause); + } else if (trail_->FailingClauseId() != kNoClauseId) { + failing_clause_id = trail_->FailingClauseId(); } else { absl::Span failing_clause = trail_->FailingClause(); if (failing_clause.size() == 2) { @@ -2342,6 +2350,7 @@ std::string SatSolver::RunningStatisticsString() const { void SatSolver::ProcessNewlyFixedVariables() { SCOPED_TIME_STAT(&stats_); DCHECK_EQ(CurrentDecisionLevel(), 0); + if (num_processed_fixed_variables_ == trail_->Index()) return; int num_detached_clauses = 0; int num_binary = 0; diff --git a/ortools/sat/stat_tables.cc b/ortools/sat/stat_tables.cc index 626a13351d..3bf7808d03 100644 --- a/ortools/sat/stat_tables.cc +++ b/ortools/sat/stat_tables.cc @@ -91,17 +91,19 @@ void SharedStatTables::AddSearchStat(absl::string_view name, Model* model) { void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) { absl::MutexLock mutex_lock(mutex_); auto* sat_solver = model->GetOrCreate(); + auto* binary = model->GetOrCreate(); SatSolver::Counters counters = sat_solver->counters(); if (clauses_table_.empty()) { clauses_table_.push_back({"SAT stats", "ClassicMinim", "LitRemoved", - "LitLearned", "LitForgotten", "Subsumed", - "MClauses", "MDecisions", "MLitTrue", "MSubsumed", - "MLitRemoved", "MReused"}); + "LitRemovedBinary", "LitLearned", "LitForgotten", + "Subsumed", "MClauses", "MDecisions", "MLitTrue", + "MSubsumed", "MLitRemoved", "MReused"}); } clauses_table_.push_back( {FormatName(name), FormatCounter(counters.num_minimizations), FormatCounter(counters.num_literals_removed), + FormatCounter(binary->num_literals_removed()), FormatCounter(counters.num_literals_learned), FormatCounter(counters.num_literals_forgotten), FormatCounter(counters.num_subsumed_clauses), @@ -117,7 +119,6 @@ void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) { bool_var_table_.push_back( {"Boolean variables", "Fixed", "Equiv", "Total", "Left", "Binary"}); } - auto* binary = model->GetOrCreate(); const int64_t num_fixed = sat_solver->NumFixedVariables(); const int64_t num_equiv = binary->num_redundant_literals() / 2; const int64_t num_bools = sat_solver->NumVariables(); diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index 44bb17a0cf..00a7c9a0fe 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -607,33 +607,6 @@ void MaxBoundedSubsetSum::AddChoicesInternal(absl::Span values) { } } -int64_t MaxBoundedSubsetSum::MaxIfAdded(int64_t candidate) const { - if (candidate > bound_ || current_max_ == bound_) return current_max_; - - int64_t current_max = current_max_; - // Mode 1: vector of all possible sums (with duplicates). - if (!sums_.empty()) { - for (const int64_t v : sums_) { - if (v + candidate > bound_) continue; - if (v + candidate > current_max) { - current_max = v + candidate; - if (current_max == bound_) return current_max; - } - } - return current_max; - } - - // Mode 2: bitset of all possible sums. - if (!expanded_sums_.empty()) { - const int64_t min_useful = std::max(0, current_max_ - candidate); - const int64_t max_useful = bound_ - candidate; - for (int64_t v = max_useful; v >= min_useful; --v) { - if (expanded_sums_[v]) return v + candidate; - } - } - return current_max_; -} - BasicKnapsackSolver::Result BasicKnapsackSolver::Solve( absl::Span domains, absl::Span coeffs, absl::Span costs, const Domain& rhs) { diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 87cf217f9a..d9a2c3ff7b 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -550,9 +550,6 @@ class MaxBoundedSubsetSum { // We look for the maximum sum <= bound. void Reset(int64_t bound); - // Returns the updated max if value was added to the subset-sum. - int64_t MaxIfAdded(int64_t candidate) const; - // Add a value to the base set for which subset sums will be taken. void Add(int64_t value); diff --git a/ortools/sat/util_test.cc b/ortools/sat/util_test.cc index b4a4f80b55..19ac63aa37 100644 --- a/ortools/sat/util_test.cc +++ b/ortools/sat/util_test.cc @@ -737,21 +737,6 @@ TEST(MaxBoundedSubsetSumTest, SimpleMultiChoice) { EXPECT_EQ(bounded_subset_sum.CurrentMax(), 31); } -TEST(MaxBoundedSubsetSumTest, CheckMaxIfAdded) { - MaxBoundedSubsetSum bounded_subset_sum(34); - bounded_subset_sum.Add(10); - bounded_subset_sum.Add(10); - bounded_subset_sum.Add(10); - EXPECT_EQ(bounded_subset_sum.MaxIfAdded(12), 32); - EXPECT_EQ(bounded_subset_sum.MaxIfAdded(15), 30); - EXPECT_EQ(bounded_subset_sum.MaxIfAdded(34), 34); - for (int i = 0; i < 100; ++i) { - bounded_subset_sum.Add(18); - } - EXPECT_EQ(bounded_subset_sum.CurrentMax(), 30); - EXPECT_EQ(bounded_subset_sum.MaxIfAdded(5), 33); -} - static void BM_bounded_subset_sum(benchmark::State& state) { random_engine_t random_; const int num_items = state.range(0); diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index 8241449c64..e7e9e1279f 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -635,18 +635,21 @@ const std::vector& SharedTreeWorker::DecisionReason(int level) { } bool SharedTreeWorker::AddDecisionImplication(Literal lit, int level) { + CHECK_GT(level, 0); CHECK_NE(lit.Index(), kNoLiteralIndex); CHECK(!sat_solver_->Assignment().LiteralIsTrue(lit)); + absl::Span reason = DecisionReason(level); if (sat_solver_->Assignment().LiteralIsFalse(lit)) { VLOG(2) << "Closing subtree via impl at " << level + 1 << " assigned=" << assigned_tree_.MaxLevel(); - integer_trail_->ReportConflict(DecisionReason(level), {}); + trail_->MutableConflict()->assign(reason.begin(), reason.end()); manager_->CloseTree(assigned_tree_, level); assigned_tree_literals_.clear(); return false; } VLOG(2) << "Learned shared clause"; - return integer_trail_->EnqueueLiteral(lit, DecisionReason(level), {}); + trail_->GetEmptyVectorToStoreReason()->assign(reason.begin(), reason.end()); + return trail_->EnqueueWithStoredReason(kNoClauseId, lit); } bool SharedTreeWorker::AddImplications() {