[CP-SAT] bug fixes, memory utilization reduction; more work on lrat

This commit is contained in:
Laurent Perron
2025-12-09 16:17:49 +01:00
committed by Mizux Seiha
parent 469d83a2ef
commit 4a2de332ce
9 changed files with 410 additions and 127 deletions

View File

@@ -26,6 +26,11 @@ namespace operations_research::sat {
using SmallBitset = uint32_t;
// This works for num_bits == 32 too.
inline SmallBitset GetNumBitsAtOne(int num_bits) {
return ~SmallBitset(0) >> (32 - (1 << num_bits));
}
// Sort the key and modify the truth table accordingly.
//
// Note that we don't deal with identical key here, but the function
@@ -35,7 +40,6 @@ template <typename VarOrLiteral>
void CanonicalizeTruthTable(absl::Span<VarOrLiteral> key,
SmallBitset& bitmask) {
const int num_bits = key.size();
CHECK_EQ(bitmask >> (1 << num_bits), 0);
for (int i = 0; i < num_bits; ++i) {
for (int j = i + 1; j < num_bits; ++j) {
if (key[i] <= key[j]) continue;
@@ -53,11 +57,9 @@ void CanonicalizeTruthTable(absl::Span<VarOrLiteral> key,
new_bitmask |= ((bitmask >> p) & 1) << new_p;
}
bitmask = new_bitmask;
CHECK_EQ(bitmask >> (1 << num_bits), 0)
<< i << " " << j << " " << num_bits;
}
}
CHECK(std::is_sorted(key.begin(), key.end()));
DCHECK(std::is_sorted(key.begin(), key.end()));
}
// Given a clause, return the truth table corresponding to it.
@@ -67,15 +69,14 @@ inline void FillKeyAndBitmask(absl::Span<const Literal> clause,
SmallBitset& bitmask) {
CHECK_EQ(clause.size(), key.size());
const int num_bits = clause.size();
bitmask = ~SmallBitset(0) >> (32 - (1 << num_bits)); // All allowed;
CHECK_EQ(bitmask >> (1 << num_bits), 0) << num_bits;
SmallBitset bit_to_remove = 0;
for (int i = 0; i < num_bits; ++i) {
key[i] = clause[i].Variable();
bit_to_remove |= (clause[i].IsPositive() ? 0 : 1) << i;
}
bitmask ^= (SmallBitset(1) << bit_to_remove);
CHECK_EQ(bitmask >> (1 << num_bits), 0) << bit_to_remove << " " << num_bits;
CHECK_LT(bit_to_remove, (1 << num_bits));
bitmask = GetNumBitsAtOne(num_bits);
bitmask ^= SmallBitset(1) << bit_to_remove;
CanonicalizeTruthTable<BooleanVariable>(key, bitmask);
}
@@ -111,15 +112,6 @@ inline int CanonicalizeFunctionTruthTable(LiteralIndex& target,
const int num_bits = inputs.size();
CHECK_LE(num_bits, 4); // Truth table must fit on an int.
const SmallBitset all_one = (1 << (1 << num_bits)) - 1;
CHECK_EQ(function_values & ~all_one, 0);
// Make sure target is positive.
if (!Literal(target).IsPositive()) {
target = Literal(target).Negated();
function_values = function_values ^ all_one;
CHECK_EQ(function_values >> (1 << num_bits), 0);
}
// Make sure all inputs are positive.
for (int i = 0; i < num_bits; ++i) {
@@ -189,6 +181,17 @@ inline int CanonicalizeFunctionTruthTable(LiteralIndex& target,
}
}
// If we have x = f(a,b,c) and not(y) = f(a,b,c) with the same f, we have an
// equivalence, so we need to canonicallize both f() and not(f()) to the same
// function. For that we just always choose to have the lowest bit at zero.
if (function_values & 1) {
target = Literal(target).Negated();
const SmallBitset all_one = GetNumBitsAtOne(inputs.size());
function_values = function_values ^ all_one;
DCHECK_EQ(function_values >> (1 << inputs.size()), 0);
}
DCHECK_EQ(function_values & 1, 0);
int_function_values = function_values;
return inputs.size();
}

View File

@@ -1667,10 +1667,6 @@ ReasonIndex IntegerTrail::AppendReasonToInternalBuffers(
return reason_index;
}
int64_t IntegerTrail::NextConflictId() {
return sat_solver_->num_failures() + 1;
}
bool IntegerTrail::EnqueueInternal(
IntegerLiteral i_lit, bool use_lazy_reason,
absl::Span<const Literal> literal_reason,
@@ -1721,7 +1717,8 @@ bool IntegerTrail::EnqueueInternal(
return false;
}
MergeReasonIntoInternal(conflict, NextConflictId());
// TODO(user): fix or remove the conflict_id optimization.
MergeReasonIntoInternal(conflict, -1);
return false;
}
@@ -1793,7 +1790,8 @@ bool IntegerTrail::EnqueueInternal(
return false;
}
MergeReasonIntoInternal(conflict, NextConflictId());
// TODO(user): fix or remove the conflict_id optimization.
MergeReasonIntoInternal(conflict, -1);
return false;
}
@@ -2348,7 +2346,10 @@ absl::Span<const Literal> IntegerTrail::Reason(
tmp_queue_.push_back(prev_trail_index);
}
// TODO(user): fix or remove the conflict_id optimization.
// It is probably superseded by the new integer resolution in any case,
// so I fill we could just remove it.
MergeReasonIntoInternal(reason, -1);
if (DEBUG_MODE && debug_checker_ != nullptr) {

View File

@@ -1090,9 +1090,6 @@ class IntegerTrail final : public SatPropagator {
// Returns some debugging info.
std::string DebugString();
// Used internally to return the next conflict number.
int64_t NextConflictId();
// Information for each integer variable about its current lower bound and
// position of the last TrailEntry in the trail referring to this var.
util_intops::StrongVector<IntegerVariable, IntegerValue> var_lbs_;

View File

@@ -122,7 +122,10 @@ void MinimizeCoreWithSearch(TimeLimit* limit, SatSolver* solver,
const int old_size = core->size();
std::vector<Literal> assumptions;
absl::flat_hash_set<LiteralIndex> removed_once;
while (true) {
// We stop as soon as the core size is one since there is nothing more
// to minimize then.
while (core->size() > 1) {
if (limit->LimitReached()) break;
// Find a not yet removed literal to remove.

View File

@@ -184,6 +184,46 @@ _IndexOrSeries = Union[pd.Index, pd.Series]
# Helper functions.
# warnings.deprecated is python3.13+. Not compatible with Open Source (3.10+).
# pylint: disable=g-bare-generic
def deprecated(message: str) -> Callable[[Callable], Callable]:
"""Decorator that warns about a deprecated function."""
def deprecated_decorator(func) -> Callable:
def deprecated_func(*args, **kwargs):
warnings.warn(
f"{func.__name__} is a deprecated function. {message}",
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func(*args, **kwargs)
return deprecated_func
return deprecated_decorator
def deprecated_method(func, old_name: str) -> Callable:
"""Wrapper that warns about a deprecated method."""
def deprecated_func(*args, **kwargs) -> Any:
warnings.warn(
f"{old_name} is a deprecated function. Use {func.__name__} instead.",
category=DeprecationWarning,
stacklevel=2,
)
warnings.simplefilter("default", DeprecationWarning)
return func(*args, **kwargs)
return deprecated_func
# pylint: enable=g-bare-generic
def snake_case_to_camel_case(name: str) -> str:
"""Converts a snake_case name to CamelCase."""
words = name.split("_")
@@ -274,16 +314,6 @@ class CpModel(cmh.CpBaseModel):
cmh.CpBaseModel.__init__(self, model_proto)
self._add_pre_pep8_methods()
def _add_pre_pep8_methods(self) -> None:
for method_name in dir(self):
if callable(getattr(self, method_name)) and (
method_name.startswith("add_")
or method_name.startswith("new_")
or method_name.startswith("clear_")
):
pre_pep8_name = snake_case_to_camel_case(method_name)
setattr(self, pre_pep8_name, getattr(self, method_name))
# Naming.
@property
def name(self) -> str:
@@ -1649,30 +1679,52 @@ class CpModel(cmh.CpBaseModel):
# Compatibility with pre PEP8
# pylint: disable=invalid-name
def _add_pre_pep8_methods(self) -> None:
for method_name in dir(self):
if callable(getattr(self, method_name)) and (
method_name.startswith("add_")
or method_name.startswith("new_")
or method_name.startswith("clear_")
):
pre_pep8_name = snake_case_to_camel_case(method_name)
setattr(
self,
pre_pep8_name,
deprecated_method(getattr(self, method_name), pre_pep8_name),
)
for other_method_name in [
"add",
"clone",
"get_bool_var_from_proto_index",
"get_int_var_from_proto_index",
"get_interval_var_from_proto_index",
"minimize",
"maximize",
"has_objective",
"model_stats",
"validate",
"export_to_file",
]:
pre_pep8_name = snake_case_to_camel_case(other_method_name)
setattr(
self,
pre_pep8_name,
deprecated_method(getattr(self, other_method_name), pre_pep8_name),
)
@deprecated("Use name property instead.")
def Name(self) -> str:
return self.name
@deprecated("Use name property instead.")
def SetName(self, name: str) -> None:
self.name = name
@deprecated("Use proto property instead.")
def Proto(self) -> cmh.CpModelProto:
return self.proto
Add = add
Clone = clone
GetBoolVarFromProtoIndex = get_bool_var_from_proto_index
GetIntVarFromProtoIndex = get_int_var_from_proto_index
GetIntervalVarFromProtoIndex = get_interval_var_from_proto_index
Minimize = minimize
Maximize = maximize
HasObjective = has_objective
ModelStats = model_stats
Validate = validate
ExportToFile = export_to_file
# add_XXX, new_XXX, and clear_XXX methods are already duplicated
# automatically.
# pylint: enable=invalid-name
@@ -1924,49 +1976,85 @@ class CpSolver:
# Compatibility with pre PEP8
# pylint: disable=invalid-name
@deprecated("Use best_objective_bound property instead.")
def BestObjectiveBound(self) -> float:
return self.best_objective_bound
BooleanValue = boolean_value
BooleanValues = boolean_values
@deprecated("Use boolean_value() method instead.")
def BooleanValue(self, lit: LiteralT) -> bool:
return self.boolean_value(lit)
@deprecated("Use boolean_values() method instead.")
def BooleanValues(self, variables: _IndexOrSeries) -> pd.Series:
return self.boolean_values(variables)
@deprecated("Use num_booleans property instead.")
def NumBooleans(self) -> int:
return self.num_booleans
@deprecated("Use num_conflicts property instead.")
def NumConflicts(self) -> int:
return self.num_conflicts
@deprecated("Use num_branches property instead.")
def NumBranches(self) -> int:
return self.num_branches
@deprecated("Use objective_value property instead.")
def ObjectiveValue(self) -> float:
return self.objective_value
@deprecated("Use response_proto property instead.")
def ResponseProto(self) -> cmh.CpSolverResponse:
return self.response_proto
ResponseStats = response_stats
Solve = solve
SolutionInfo = solution_info
StatusName = status_name
StopSearch = stop_search
SufficientAssumptionsForInfeasibility = sufficient_assumptions_for_infeasibility
@deprecated("Use response_stats() method instead.")
def ResponseStats(self) -> str:
return self.response_stats()
@deprecated("Use solve() method instead.")
def Solve(
self, model: CpModel, callback: "CpSolverSolutionCallback" = None
) -> cmh.CpSolverStatus:
return self.solve(model, callback)
@deprecated("Use solution_info() method instead.")
def SolutionInfo(self) -> str:
return self.solution_info()
@deprecated("Use status_name() method instead.")
def StatusName(self, status: Optional[Any] = None) -> str:
return self.status_name(status)
@deprecated("Use stop_search() method instead.")
def StopSearch(self) -> None:
self.stop_search()
@deprecated("Use sufficient_assumptions_for_infeasibility() method instead.")
def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]:
return self.sufficient_assumptions_for_infeasibility()
@deprecated("Use user_time property instead.")
def UserTime(self) -> float:
return self.user_time
Value = value
Values = values
@deprecated("Use value() method instead.")
def Value(self, expression: LinearExprT) -> int:
return self.value(expression)
@deprecated("Use values() method instead.")
def Values(self, expressions: _IndexOrSeries) -> pd.Series:
return self.values(expressions)
@deprecated("Use wall_time property instead.")
def WallTime(self) -> float:
return self.wall_time
@deprecated("Use solve() with enumerate_all_solutions = True.")
def SearchForAllSolutions(
self, model: CpModel, callback: "CpSolverSolutionCallback"
) -> cmh.CpSolverStatus:
"""DEPRECATED Use solve() with the right parameter.
Search for all solutions of a satisfiability problem.
"""Search for all solutions of a satisfiability problem.
This method searches for all feasible solutions of a given model.
Then it feeds the solution to the callback.
@@ -1984,11 +2072,6 @@ class CpSolver:
* *INFEASIBLE* if the solver has proved there are no solution
* *OPTIMAL* if all solutions have been found
"""
warnings.warn(
"search_for_all_solutions is deprecated; use solve() with"
+ "enumerate_all_solutions = True.",
DeprecationWarning,
)
if model.has_objective():
raise TypeError(
"Search for all solutions is only defined on satisfiability problems"

View File

@@ -2767,10 +2767,13 @@ TRFM"""
model = cp_model.CpModel()
x = [model.NewBoolVar(f"x{i}") for i in range(5)]
model.AddBoolOr(x)
model.Maximize(sum(x))
self.assertLen(model.proto.variables, 5)
self.assertLen(model.proto.constraints, 1)
self.assertLen(model.proto.constraints[0].bool_or.literals, 5)
self.assertTrue(hasattr(model, "Proto"))
model_copy = copy.copy(model)
self.assertTrue(hasattr(model_copy, "AddBoolOr"))
self.assertTrue(hasattr(model_copy, "AddBoolXOr"))

View File

@@ -89,6 +89,7 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) {
// Mainly useful for development.
double probing_time = 0.0;
const bool log_info = VLOG_IS_ON(2);
const bool log_round_info = VLOG_IS_ON(2);
// We currently do the transformations in a given order and restart each time
@@ -140,6 +141,13 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) {
continue;
}
// TODO(user): Think about the right order in this function.
if (params_.inprocessing_use_congruence_closure()) {
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables());
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_info));
}
// TODO(user): Combine the two? this way we don't create a full literal <->
// clause graph twice. It might make sense to reach the BCE fix point which
// is unique before each variable elimination.
@@ -147,13 +155,6 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) {
blocked_clause_simplifier_->DoOneRound(log_round_info);
}
// TODO(user): Think about the right order in this function.
if (params_.inprocessing_use_congruence_closure()) {
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables());
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_round_info));
}
// TODO(user): this break some binary graph invariant. Fix!
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
RETURN_IF_FALSE(bounded_variable_elimination_->DoOneRound(log_round_info));
@@ -298,7 +299,7 @@ bool Inprocessing::InprocessingRound() {
if (params_.inprocessing_use_congruence_closure()) {
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables());
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_round_info));
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_info));
}
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
@@ -1956,6 +1957,7 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
PresolveTimer& timer) {
ids3_.clear();
ids4_.clear();
ids5_.clear();
truth_tables_inputs_.clear();
truth_tables_bitset_.clear();
truth_tables_clauses_.clear();
@@ -1971,6 +1973,8 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
AddToTruthTable<3>(clause, ids3_);
} else if (clause->size() == 4) {
AddToTruthTable<4>(clause, ids4_);
} else if (clause->size() == 5) {
AddToTruthTable<5>(clause, ids5_);
}
// Used for an optimization below.
@@ -2103,68 +2107,235 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
timer.AddCounter("and_gates", gates_inputs_.size());
}
int GateCongruenceClosure::ProcessTruthTable(
absl::Span<const BooleanVariable> inputs, SmallBitset truth_table,
absl::Span<const TruthTableId> ids_for_proof) {
int num_detected = 0;
for (int i = 0; i < inputs.size(); ++i) {
if (!IsFunction(i, inputs.size(), truth_table)) continue;
const int num_bits = inputs.size() - 1;
++num_detected;
gates_target_.push_back(Literal(inputs[i], true));
gates_inputs_.Add({});
for (int j = 0; j < inputs.size(); ++j) {
if (i != j) {
gates_inputs_.AppendToLastVector(Literal(inputs[j], true));
}
}
// Generate the function truth table as a type.
// We will canonicalize it further in the main loop.
unsigned int type = 0;
for (int p = 0; p < (1 << num_bits); ++p) {
// Expand from (num_bits == inputs.size() - 1) bits to (inputs.size())
// bits.
const int bigger_p = AddHoleAtPosition(i, p);
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
// target is 1 at this position.
type |= 1 << p;
DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function.
} else {
// Note that if there is no feasible assignment for a given p if
// [(truth_table >> bigger_p) & 1] is not 1.
//
// Here we could learn smaller clause, but also we don't really care
// what is the value of the target at that point p, so we use zero.
//
// TODO(user): This is not ideal if another function has a value of 1 at
// point p, we could still merge it with this one. Shall we create two
// gate type? or change the algo?
}
}
gates_type_.push_back(type);
if (lrat_proof_handler_ != nullptr) {
gates_clauses_.Add({});
for (const TruthTableId id : ids_for_proof) {
gates_clauses_.AppendToLastVector(truth_tables_clauses_[id]);
}
}
}
return num_detected;
}
namespace {
// Return a BooleanVariable from b that is not in a or kNoBooleanVariable;
BooleanVariable FindMissing(absl::Span<const BooleanVariable> vars_a,
absl::Span<const BooleanVariable> vars_b) {
for (const BooleanVariable b : vars_b) {
bool found = false;
for (const BooleanVariable a : vars_a) {
if (a == b) {
found = true;
break;
}
}
if (!found) return b;
}
return kNoBooleanVariable;
}
} // namespace
// TODO(user): It should be possible to extract ALL possible short gate, but
// we are not there yet.
void GateCongruenceClosure::ExtractShortGates(PresolveTimer& timer) {
if (lrat_proof_handler_ != nullptr) {
truth_tables_clauses_.ResetFromFlatMapping(tmp_ids_, tmp_clauses_);
truth_tables_clauses_.ResetFromFlatMapping(
tmp_ids_, tmp_clauses_,
/*minimum_num_nodes=*/truth_tables_bitset_.size());
CHECK_EQ(truth_tables_bitset_.size(), truth_tables_clauses_.size());
}
// This is used to combine two 3 arity table into one 4 arity one if
// they share two variables.
absl::flat_hash_map<std::array<BooleanVariable, 2>, int> binary_index_map;
std::vector<int> flat_binary_index;
std::vector<TruthTableId> flat_table_id;
// Counters.
// We only fill a subset of the entries, but that makes the code shorter.
std::vector<int> num_tables(5);
std::vector<int> num_functions(5);
// We only fill a subset of the entries, but that make the code shorter.
std::vector<int> num_tables(6);
std::vector<int> num_functions(6);
// Note that using the indirection via TruthTableId allow this code to
// be deterministic.
CHECK_EQ(truth_tables_bitset_.size(), truth_tables_inputs_.size());
for (TruthTableId id(0); id < truth_tables_inputs_.size(); ++id) {
const absl::Span<const BooleanVariable> inputs = truth_tables_inputs_[id];
const SmallBitset truth_table = truth_tables_bitset_[id];
++num_tables[inputs.size()];
for (int i = 0; i < inputs.size(); ++i) {
if (!IsFunction(i, inputs.size(), truth_table)) continue;
const int num_bits = inputs.size() - 1;
++num_functions[num_bits];
gates_target_.push_back(Literal(inputs[i], true));
gates_inputs_.Add({});
for (int j = 0; j < inputs.size(); ++j) {
if (i != j) {
gates_inputs_.AppendToLastVector(Literal(inputs[j], true));
}
// Given a table of arity 4, this merges all the information from the tables
// of arity 3 included in it.
int num_merges = 0;
const auto merge3_into_4 = [this, &num_merges](
absl::Span<const BooleanVariable> inputs,
SmallBitset& truth_table,
std::vector<TruthTableId>& ids_for_proof) {
DCHECK_EQ(inputs.size(), 4);
for (int i_to_remove = 0; i_to_remove < inputs.size(); ++i_to_remove) {
int pos = 0;
std::array<BooleanVariable, 3> key3;
for (int i = 0; i < inputs.size(); ++i) {
if (i == i_to_remove) continue;
key3[pos++] = inputs[i];
}
const auto it = ids3_.find(key3);
if (it == ids3_.end()) continue;
++num_merges;
// Generate the function truth table as a type.
// We will canonicalize it further in the main loop.
unsigned int type = 0;
for (int p = 0; p < (1 << num_bits); ++p) {
// Expand from (num_bits == inputs.size() - 1) bits to (inputs.size())
// bits.
const int bigger_p = AddHoleAtPosition(i, p);
// Extend the bitset from the table3 so that it is expressed correctly
// for the given inputs.
const TruthTableId id3 = it->second;
std::array<BooleanVariable, 4> key4;
for (int i = 0; i < 3; ++i) key4[i] = key3[i];
key4[3] = FindMissing(key3, inputs);
SmallBitset bitset = truth_tables_bitset_[id3];
bitset |= bitset << (1 << 3); // Extend for a new variable.
CanonicalizeTruthTable<BooleanVariable>(absl::MakeSpan(key4), bitset);
CHECK_EQ(inputs, absl::MakeSpan(key4));
truth_table &= bitset;
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
// target is 1 at this position.
type |= 1 << p;
DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function.
} else {
// Note that if there is no feasible assignment for a given p, first
// we could have learned a smaller clause, but also we don't really
// care what is the value of the target at that point p, so we use
// zero.
}
}
gates_type_.push_back(type);
// We need to add the corresponding clause!
if (lrat_proof_handler_ != nullptr) {
gates_clauses_.Add(truth_tables_clauses_[id]);
ids_for_proof.push_back(id3);
}
}
};
// Starts by processing all existing tables.
//
// TODO(user): Since we deal with and_gates differently, do we need to look at
// binary clauses here ? or that is not needed ? I think there is only two
// kind of Boolean function on two inputs (and_gates, with any possible
// negation) and xor_gate.
std::vector<TruthTableId> ids_for_proof;
for (TruthTableId t_id(0); t_id < truth_tables_inputs_.size(); ++t_id) {
ids_for_proof.clear();
ids_for_proof.push_back(t_id);
const absl::Span<const BooleanVariable> inputs = truth_tables_inputs_[t_id];
SmallBitset truth_table = truth_tables_bitset_[t_id];
// Merge any size-3 table included inside a size-4 gate.
// TODO(user): do it for larger gate too ?
if (inputs.size() == 4) {
merge3_into_4(inputs, truth_table, ids_for_proof);
}
++num_tables[inputs.size()];
const int num_detected =
ProcessTruthTable(inputs, truth_table, ids_for_proof);
num_functions[inputs.size() - 1] += num_detected;
// If this is not a function and of size 3, lets try to combine it with
// another truth table of size 3 to get a table of size 4.
if (inputs.size() == 3 && num_detected == 0) {
for (int i = 0; i < 3; ++i) {
std::array<BooleanVariable, 2> key{inputs[i != 0 ? 0 : 1],
inputs[i != 2 ? 2 : 1]};
DCHECK(std::is_sorted(key.begin(), key.end()));
const auto [it, inserted] =
binary_index_map.insert({key, binary_index_map.size()});
flat_binary_index.push_back(it->second);
flat_table_id.push_back(t_id);
}
}
}
gtl::STLClearObject(&binary_index_map);
// Detects ITE gates and potentially other 3-gates from a truth table of
// 4-entries formed by two 3-entries table. This just create a 4-entries
// table that will be processed below.
CompactVectorVector<int, TruthTableId> candidates;
candidates.ResetFromFlatMapping(flat_binary_index, flat_table_id);
gtl::STLClearObject(&flat_binary_index);
gtl::STLClearObject(&flat_table_id);
int num_combinations = 0;
for (int c = 0; c < candidates.size(); ++c) {
if (candidates[c].size() < 2) continue;
if (candidates[c].size() > 10) continue; // Too many? use heuristic.
for (int a = 0; a < candidates[c].size(); ++a) {
for (int b = a + 1; b < candidates[c].size(); ++b) {
const absl::Span<const BooleanVariable> inputs_a =
truth_tables_inputs_[candidates[c][a]];
const absl::Span<const BooleanVariable> inputs_b =
truth_tables_inputs_[candidates[c][b]];
std::array<BooleanVariable, 4> key;
for (int i = 0; i < 3; ++i) key[i] = inputs_a[i];
key[3] = FindMissing(inputs_a, inputs_b);
CHECK_NE(key[3], kNoBooleanVariable);
// Add an all allowed entry.
SmallBitset bitmask = GetNumBitsAtOne(4);
CanonicalizeTruthTable<BooleanVariable>(absl::MakeSpan(key), bitmask);
// If the key was not processed before, process it now.
// Note that an old version created a TruthTableId for it, but that
// waste a lot of space.
//
// On another hand, it is possible we process the same key up to
// 4_choose_2 times, but this is rare...
if (!ids4_.contains(key)) {
++num_combinations;
++num_tables[4];
ids_for_proof.clear();
merge3_into_4(key, bitmask, ids_for_proof);
num_functions[3] += ProcessTruthTable(key, bitmask, ids_for_proof);
}
}
}
}
timer.AddCounter("combine3", num_combinations);
timer.AddCounter("merges", num_merges);
// Note that we only display non-zero counters.
for (int i = 2; i < 5; ++i) {
timer.AddCounter(absl::StrCat("table", i), num_tables[i]);
for (int i = 2; i < num_tables.size(); ++i) {
timer.AddCounter(absl::StrCat("t", i), num_tables[i]);
}
for (int i = 2; i < num_functions.size(); ++i) {
timer.AddCounter(absl::StrCat("fn", i), num_functions[i]);
}
}
@@ -2583,8 +2754,9 @@ class LratGateCongruenceHelper {
} // namespace
bool GateCongruenceClosure::DoOneRound(bool log_info) {
// TODO(user): Remove this condition, it is possible there are no binary
// and still gates!
if (implication_graph_->IsEmpty()) return true;
clause_manager_->DetachAllClauses();
PresolveTimer timer("GateCongruenceClosure", logger_, time_limit_);
timer.OverrideLogging(log_info);
@@ -2725,9 +2897,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
const LiteralIndex y =
negate ? rep_b.NegatedIndex() : rep_b.Index();
// TODO(user): We need to change the union_find algo used to be sure
// the that if rep(x) = y then rep(not(x)) = not(y), otherwise we
// might miss some reductions.
// Because x always refer to a and y to b, this should maintain
// the invariant root(lit) = root(lit.Negated()).Negated().
// This is checked below.
union_find.AddEdge(x.value(), y.value());
const LiteralIndex rep(union_find.FindRoot(y.value()));
const LiteralIndex to_merge = rep == x ? y : x;
@@ -2758,6 +2930,12 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
in_queue[gate_id.value()] = true;
}
}
// Invariant.
CHECK_EQ(
lrat_helper.GetRepresentativeWithProofSupport(rep_a),
lrat_helper.GetRepresentativeWithProofSupport(rep_a.Negated())
.Negated());
}
break;
}
@@ -2840,8 +3018,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// Generic "short" gates.
// We just take the representative and re-canonicalize.
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
CHECK_GE(gates_type_[id], 0);
CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
DCHECK_GE(gates_type_[id], 0);
DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
for (LiteralIndex& lit_ref : inputs) {
lit_ref =
lrat_helper.GetRepresentativeWithProofSupport(Literal(lit_ref))
@@ -2853,6 +3031,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
if (new_size < inputs.size()) {
gates_inputs_.Shrink(id, new_size);
}
DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
if (new_size == 1) {
// We have a function of size 1! This is an equivalence.
@@ -2863,7 +3042,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
break;
} else if (new_size == 0) {
// We have a fixed function! Just fix the literal.
CHECK(Literal(gates_target_[id]).IsPositive());
const Literal initial_to_fix =
(gates_type_[id] & 1) == 1 ? Literal(gates_target_[id])
: Literal(gates_target_[id]).Negated();

View File

@@ -500,6 +500,12 @@ class GateCongruenceClosure {
// functions of the form one_var = f(other_vars).
void ExtractShortGates(PresolveTimer& timer);
// Detects gates encoded in the given truth table, and add them to the set
// of gates. Returns the number of gate detected.
int ProcessTruthTable(absl::Span<const BooleanVariable> inputs,
SmallBitset truth_table,
absl::Span<const TruthTableId> ids_for_proof = {});
// Add a small clause to the corresponding truth table.
template <int arity>
void AddToTruthTable(SatClause* clause,
@@ -550,6 +556,7 @@ class GateCongruenceClosure {
// truth_tables_inputs_, this is a bit wasted but simplify the code.
absl::flat_hash_map<std::array<BooleanVariable, 3>, TruthTableId> ids3_;
absl::flat_hash_map<std::array<BooleanVariable, 4>, TruthTableId> ids4_;
absl::flat_hash_map<std::array<BooleanVariable, 5>, TruthTableId> ids5_;
CompactVectorVector<TruthTableId, BooleanVariable> truth_tables_inputs_;
util_intops::StrongVector<TruthTableId, SmallBitset> truth_tables_bitset_;
CompactVectorVector<TruthTableId, SatClause*> truth_tables_clauses_;

View File

@@ -134,6 +134,7 @@ class CompactVectorVector {
// Returns the previous size() as this is convenient for how we use it.
int Add(absl::Span<const V> values);
void AppendToLastVector(const V& value);
void AppendToLastVector(absl::Span<const V> values);
// Hacky: same as Add() but for sat::Literal or any type from which we can get
// a value type V via L.Index().value().
@@ -919,6 +920,13 @@ inline void CompactVectorVector<K, V>::AppendToLastVector(const V& value) {
buffer_.push_back(value);
}
template <typename K, typename V>
inline void CompactVectorVector<K, V>::AppendToLastVector(
absl::Span<const V> values) {
sizes_.back() += values.size();
buffer_.insert(buffer_.end(), values.begin(), values.end());
}
template <typename K, typename V>
inline void CompactVectorVector<K, V>::ReplaceValuesBySmallerSet(
K key, absl::Span<const V> values) {