[CP-SAT] bug fixes, memory utilization reduction; more work on lrat
This commit is contained in:
committed by
Mizux Seiha
parent
469d83a2ef
commit
4a2de332ce
@@ -26,6 +26,11 @@ namespace operations_research::sat {
|
||||
|
||||
using SmallBitset = uint32_t;
|
||||
|
||||
// This works for num_bits == 32 too.
|
||||
inline SmallBitset GetNumBitsAtOne(int num_bits) {
|
||||
return ~SmallBitset(0) >> (32 - (1 << num_bits));
|
||||
}
|
||||
|
||||
// Sort the key and modify the truth table accordingly.
|
||||
//
|
||||
// Note that we don't deal with identical key here, but the function
|
||||
@@ -35,7 +40,6 @@ template <typename VarOrLiteral>
|
||||
void CanonicalizeTruthTable(absl::Span<VarOrLiteral> key,
|
||||
SmallBitset& bitmask) {
|
||||
const int num_bits = key.size();
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0);
|
||||
for (int i = 0; i < num_bits; ++i) {
|
||||
for (int j = i + 1; j < num_bits; ++j) {
|
||||
if (key[i] <= key[j]) continue;
|
||||
@@ -53,11 +57,9 @@ void CanonicalizeTruthTable(absl::Span<VarOrLiteral> key,
|
||||
new_bitmask |= ((bitmask >> p) & 1) << new_p;
|
||||
}
|
||||
bitmask = new_bitmask;
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0)
|
||||
<< i << " " << j << " " << num_bits;
|
||||
}
|
||||
}
|
||||
CHECK(std::is_sorted(key.begin(), key.end()));
|
||||
DCHECK(std::is_sorted(key.begin(), key.end()));
|
||||
}
|
||||
|
||||
// Given a clause, return the truth table corresponding to it.
|
||||
@@ -67,15 +69,14 @@ inline void FillKeyAndBitmask(absl::Span<const Literal> clause,
|
||||
SmallBitset& bitmask) {
|
||||
CHECK_EQ(clause.size(), key.size());
|
||||
const int num_bits = clause.size();
|
||||
bitmask = ~SmallBitset(0) >> (32 - (1 << num_bits)); // All allowed;
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0) << num_bits;
|
||||
SmallBitset bit_to_remove = 0;
|
||||
for (int i = 0; i < num_bits; ++i) {
|
||||
key[i] = clause[i].Variable();
|
||||
bit_to_remove |= (clause[i].IsPositive() ? 0 : 1) << i;
|
||||
}
|
||||
bitmask ^= (SmallBitset(1) << bit_to_remove);
|
||||
CHECK_EQ(bitmask >> (1 << num_bits), 0) << bit_to_remove << " " << num_bits;
|
||||
CHECK_LT(bit_to_remove, (1 << num_bits));
|
||||
bitmask = GetNumBitsAtOne(num_bits);
|
||||
bitmask ^= SmallBitset(1) << bit_to_remove;
|
||||
CanonicalizeTruthTable<BooleanVariable>(key, bitmask);
|
||||
}
|
||||
|
||||
@@ -111,15 +112,6 @@ inline int CanonicalizeFunctionTruthTable(LiteralIndex& target,
|
||||
|
||||
const int num_bits = inputs.size();
|
||||
CHECK_LE(num_bits, 4); // Truth table must fit on an int.
|
||||
const SmallBitset all_one = (1 << (1 << num_bits)) - 1;
|
||||
CHECK_EQ(function_values & ~all_one, 0);
|
||||
|
||||
// Make sure target is positive.
|
||||
if (!Literal(target).IsPositive()) {
|
||||
target = Literal(target).Negated();
|
||||
function_values = function_values ^ all_one;
|
||||
CHECK_EQ(function_values >> (1 << num_bits), 0);
|
||||
}
|
||||
|
||||
// Make sure all inputs are positive.
|
||||
for (int i = 0; i < num_bits; ++i) {
|
||||
@@ -189,6 +181,17 @@ inline int CanonicalizeFunctionTruthTable(LiteralIndex& target,
|
||||
}
|
||||
}
|
||||
|
||||
// If we have x = f(a,b,c) and not(y) = f(a,b,c) with the same f, we have an
|
||||
// equivalence, so we need to canonicallize both f() and not(f()) to the same
|
||||
// function. For that we just always choose to have the lowest bit at zero.
|
||||
if (function_values & 1) {
|
||||
target = Literal(target).Negated();
|
||||
const SmallBitset all_one = GetNumBitsAtOne(inputs.size());
|
||||
function_values = function_values ^ all_one;
|
||||
DCHECK_EQ(function_values >> (1 << inputs.size()), 0);
|
||||
}
|
||||
DCHECK_EQ(function_values & 1, 0);
|
||||
|
||||
int_function_values = function_values;
|
||||
return inputs.size();
|
||||
}
|
||||
|
||||
@@ -1667,10 +1667,6 @@ ReasonIndex IntegerTrail::AppendReasonToInternalBuffers(
|
||||
return reason_index;
|
||||
}
|
||||
|
||||
int64_t IntegerTrail::NextConflictId() {
|
||||
return sat_solver_->num_failures() + 1;
|
||||
}
|
||||
|
||||
bool IntegerTrail::EnqueueInternal(
|
||||
IntegerLiteral i_lit, bool use_lazy_reason,
|
||||
absl::Span<const Literal> literal_reason,
|
||||
@@ -1721,7 +1717,8 @@ bool IntegerTrail::EnqueueInternal(
|
||||
return false;
|
||||
}
|
||||
|
||||
MergeReasonIntoInternal(conflict, NextConflictId());
|
||||
// TODO(user): fix or remove the conflict_id optimization.
|
||||
MergeReasonIntoInternal(conflict, -1);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1793,7 +1790,8 @@ bool IntegerTrail::EnqueueInternal(
|
||||
return false;
|
||||
}
|
||||
|
||||
MergeReasonIntoInternal(conflict, NextConflictId());
|
||||
// TODO(user): fix or remove the conflict_id optimization.
|
||||
MergeReasonIntoInternal(conflict, -1);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2348,7 +2346,10 @@ absl::Span<const Literal> IntegerTrail::Reason(
|
||||
|
||||
tmp_queue_.push_back(prev_trail_index);
|
||||
}
|
||||
|
||||
// TODO(user): fix or remove the conflict_id optimization.
|
||||
// It is probably superseded by the new integer resolution in any case,
|
||||
// so I fill we could just remove it.
|
||||
MergeReasonIntoInternal(reason, -1);
|
||||
|
||||
if (DEBUG_MODE && debug_checker_ != nullptr) {
|
||||
|
||||
@@ -1090,9 +1090,6 @@ class IntegerTrail final : public SatPropagator {
|
||||
// Returns some debugging info.
|
||||
std::string DebugString();
|
||||
|
||||
// Used internally to return the next conflict number.
|
||||
int64_t NextConflictId();
|
||||
|
||||
// Information for each integer variable about its current lower bound and
|
||||
// position of the last TrailEntry in the trail referring to this var.
|
||||
util_intops::StrongVector<IntegerVariable, IntegerValue> var_lbs_;
|
||||
|
||||
@@ -122,7 +122,10 @@ void MinimizeCoreWithSearch(TimeLimit* limit, SatSolver* solver,
|
||||
const int old_size = core->size();
|
||||
std::vector<Literal> assumptions;
|
||||
absl::flat_hash_set<LiteralIndex> removed_once;
|
||||
while (true) {
|
||||
|
||||
// We stop as soon as the core size is one since there is nothing more
|
||||
// to minimize then.
|
||||
while (core->size() > 1) {
|
||||
if (limit->LimitReached()) break;
|
||||
|
||||
// Find a not yet removed literal to remove.
|
||||
|
||||
@@ -184,6 +184,46 @@ _IndexOrSeries = Union[pd.Index, pd.Series]
|
||||
|
||||
|
||||
# Helper functions.
|
||||
|
||||
|
||||
# warnings.deprecated is python3.13+. Not compatible with Open Source (3.10+).
|
||||
# pylint: disable=g-bare-generic
|
||||
def deprecated(message: str) -> Callable[[Callable], Callable]:
|
||||
"""Decorator that warns about a deprecated function."""
|
||||
|
||||
def deprecated_decorator(func) -> Callable:
|
||||
def deprecated_func(*args, **kwargs):
|
||||
warnings.warn(
|
||||
f"{func.__name__} is a deprecated function. {message}",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
warnings.simplefilter("default", DeprecationWarning)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return deprecated_func
|
||||
|
||||
return deprecated_decorator
|
||||
|
||||
|
||||
def deprecated_method(func, old_name: str) -> Callable:
|
||||
"""Wrapper that warns about a deprecated method."""
|
||||
|
||||
def deprecated_func(*args, **kwargs) -> Any:
|
||||
warnings.warn(
|
||||
f"{old_name} is a deprecated function. Use {func.__name__} instead.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
warnings.simplefilter("default", DeprecationWarning)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return deprecated_func
|
||||
|
||||
|
||||
# pylint: enable=g-bare-generic
|
||||
|
||||
|
||||
def snake_case_to_camel_case(name: str) -> str:
|
||||
"""Converts a snake_case name to CamelCase."""
|
||||
words = name.split("_")
|
||||
@@ -274,16 +314,6 @@ class CpModel(cmh.CpBaseModel):
|
||||
cmh.CpBaseModel.__init__(self, model_proto)
|
||||
self._add_pre_pep8_methods()
|
||||
|
||||
def _add_pre_pep8_methods(self) -> None:
|
||||
for method_name in dir(self):
|
||||
if callable(getattr(self, method_name)) and (
|
||||
method_name.startswith("add_")
|
||||
or method_name.startswith("new_")
|
||||
or method_name.startswith("clear_")
|
||||
):
|
||||
pre_pep8_name = snake_case_to_camel_case(method_name)
|
||||
setattr(self, pre_pep8_name, getattr(self, method_name))
|
||||
|
||||
# Naming.
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -1649,30 +1679,52 @@ class CpModel(cmh.CpBaseModel):
|
||||
# Compatibility with pre PEP8
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
def _add_pre_pep8_methods(self) -> None:
|
||||
for method_name in dir(self):
|
||||
if callable(getattr(self, method_name)) and (
|
||||
method_name.startswith("add_")
|
||||
or method_name.startswith("new_")
|
||||
or method_name.startswith("clear_")
|
||||
):
|
||||
pre_pep8_name = snake_case_to_camel_case(method_name)
|
||||
setattr(
|
||||
self,
|
||||
pre_pep8_name,
|
||||
deprecated_method(getattr(self, method_name), pre_pep8_name),
|
||||
)
|
||||
|
||||
for other_method_name in [
|
||||
"add",
|
||||
"clone",
|
||||
"get_bool_var_from_proto_index",
|
||||
"get_int_var_from_proto_index",
|
||||
"get_interval_var_from_proto_index",
|
||||
"minimize",
|
||||
"maximize",
|
||||
"has_objective",
|
||||
"model_stats",
|
||||
"validate",
|
||||
"export_to_file",
|
||||
]:
|
||||
pre_pep8_name = snake_case_to_camel_case(other_method_name)
|
||||
setattr(
|
||||
self,
|
||||
pre_pep8_name,
|
||||
deprecated_method(getattr(self, other_method_name), pre_pep8_name),
|
||||
)
|
||||
|
||||
@deprecated("Use name property instead.")
|
||||
def Name(self) -> str:
|
||||
return self.name
|
||||
|
||||
@deprecated("Use name property instead.")
|
||||
def SetName(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
@deprecated("Use proto property instead.")
|
||||
def Proto(self) -> cmh.CpModelProto:
|
||||
return self.proto
|
||||
|
||||
Add = add
|
||||
Clone = clone
|
||||
GetBoolVarFromProtoIndex = get_bool_var_from_proto_index
|
||||
GetIntVarFromProtoIndex = get_int_var_from_proto_index
|
||||
GetIntervalVarFromProtoIndex = get_interval_var_from_proto_index
|
||||
Minimize = minimize
|
||||
Maximize = maximize
|
||||
HasObjective = has_objective
|
||||
ModelStats = model_stats
|
||||
Validate = validate
|
||||
ExportToFile = export_to_file
|
||||
|
||||
# add_XXX, new_XXX, and clear_XXX methods are already duplicated
|
||||
# automatically.
|
||||
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
@@ -1924,49 +1976,85 @@ class CpSolver:
|
||||
# Compatibility with pre PEP8
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
@deprecated("Use best_objective_bound property instead.")
|
||||
def BestObjectiveBound(self) -> float:
|
||||
return self.best_objective_bound
|
||||
|
||||
BooleanValue = boolean_value
|
||||
BooleanValues = boolean_values
|
||||
@deprecated("Use boolean_value() method instead.")
|
||||
def BooleanValue(self, lit: LiteralT) -> bool:
|
||||
return self.boolean_value(lit)
|
||||
|
||||
@deprecated("Use boolean_values() method instead.")
|
||||
def BooleanValues(self, variables: _IndexOrSeries) -> pd.Series:
|
||||
return self.boolean_values(variables)
|
||||
|
||||
@deprecated("Use num_booleans property instead.")
|
||||
def NumBooleans(self) -> int:
|
||||
return self.num_booleans
|
||||
|
||||
@deprecated("Use num_conflicts property instead.")
|
||||
def NumConflicts(self) -> int:
|
||||
return self.num_conflicts
|
||||
|
||||
@deprecated("Use num_branches property instead.")
|
||||
def NumBranches(self) -> int:
|
||||
return self.num_branches
|
||||
|
||||
@deprecated("Use objective_value property instead.")
|
||||
def ObjectiveValue(self) -> float:
|
||||
return self.objective_value
|
||||
|
||||
@deprecated("Use response_proto property instead.")
|
||||
def ResponseProto(self) -> cmh.CpSolverResponse:
|
||||
return self.response_proto
|
||||
|
||||
ResponseStats = response_stats
|
||||
Solve = solve
|
||||
SolutionInfo = solution_info
|
||||
StatusName = status_name
|
||||
StopSearch = stop_search
|
||||
SufficientAssumptionsForInfeasibility = sufficient_assumptions_for_infeasibility
|
||||
@deprecated("Use response_stats() method instead.")
|
||||
def ResponseStats(self) -> str:
|
||||
return self.response_stats()
|
||||
|
||||
@deprecated("Use solve() method instead.")
|
||||
def Solve(
|
||||
self, model: CpModel, callback: "CpSolverSolutionCallback" = None
|
||||
) -> cmh.CpSolverStatus:
|
||||
return self.solve(model, callback)
|
||||
|
||||
@deprecated("Use solution_info() method instead.")
|
||||
def SolutionInfo(self) -> str:
|
||||
return self.solution_info()
|
||||
|
||||
@deprecated("Use status_name() method instead.")
|
||||
def StatusName(self, status: Optional[Any] = None) -> str:
|
||||
return self.status_name(status)
|
||||
|
||||
@deprecated("Use stop_search() method instead.")
|
||||
def StopSearch(self) -> None:
|
||||
self.stop_search()
|
||||
|
||||
@deprecated("Use sufficient_assumptions_for_infeasibility() method instead.")
|
||||
def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]:
|
||||
return self.sufficient_assumptions_for_infeasibility()
|
||||
|
||||
@deprecated("Use user_time property instead.")
|
||||
def UserTime(self) -> float:
|
||||
return self.user_time
|
||||
|
||||
Value = value
|
||||
Values = values
|
||||
@deprecated("Use value() method instead.")
|
||||
def Value(self, expression: LinearExprT) -> int:
|
||||
return self.value(expression)
|
||||
|
||||
@deprecated("Use values() method instead.")
|
||||
def Values(self, expressions: _IndexOrSeries) -> pd.Series:
|
||||
return self.values(expressions)
|
||||
|
||||
@deprecated("Use wall_time property instead.")
|
||||
def WallTime(self) -> float:
|
||||
return self.wall_time
|
||||
|
||||
@deprecated("Use solve() with enumerate_all_solutions = True.")
|
||||
def SearchForAllSolutions(
|
||||
self, model: CpModel, callback: "CpSolverSolutionCallback"
|
||||
) -> cmh.CpSolverStatus:
|
||||
"""DEPRECATED Use solve() with the right parameter.
|
||||
|
||||
Search for all solutions of a satisfiability problem.
|
||||
"""Search for all solutions of a satisfiability problem.
|
||||
|
||||
This method searches for all feasible solutions of a given model.
|
||||
Then it feeds the solution to the callback.
|
||||
@@ -1984,11 +2072,6 @@ class CpSolver:
|
||||
* *INFEASIBLE* if the solver has proved there are no solution
|
||||
* *OPTIMAL* if all solutions have been found
|
||||
"""
|
||||
warnings.warn(
|
||||
"search_for_all_solutions is deprecated; use solve() with"
|
||||
+ "enumerate_all_solutions = True.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if model.has_objective():
|
||||
raise TypeError(
|
||||
"Search for all solutions is only defined on satisfiability problems"
|
||||
|
||||
@@ -2767,10 +2767,13 @@ TRFM"""
|
||||
model = cp_model.CpModel()
|
||||
x = [model.NewBoolVar(f"x{i}") for i in range(5)]
|
||||
model.AddBoolOr(x)
|
||||
model.Maximize(sum(x))
|
||||
self.assertLen(model.proto.variables, 5)
|
||||
self.assertLen(model.proto.constraints, 1)
|
||||
self.assertLen(model.proto.constraints[0].bool_or.literals, 5)
|
||||
|
||||
self.assertTrue(hasattr(model, "Proto"))
|
||||
|
||||
model_copy = copy.copy(model)
|
||||
self.assertTrue(hasattr(model_copy, "AddBoolOr"))
|
||||
self.assertTrue(hasattr(model_copy, "AddBoolXOr"))
|
||||
|
||||
@@ -89,6 +89,7 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) {
|
||||
|
||||
// Mainly useful for development.
|
||||
double probing_time = 0.0;
|
||||
const bool log_info = VLOG_IS_ON(2);
|
||||
const bool log_round_info = VLOG_IS_ON(2);
|
||||
|
||||
// We currently do the transformations in a given order and restart each time
|
||||
@@ -140,6 +141,13 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(user): Think about the right order in this function.
|
||||
if (params_.inprocessing_use_congruence_closure()) {
|
||||
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
|
||||
RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables());
|
||||
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_info));
|
||||
}
|
||||
|
||||
// TODO(user): Combine the two? this way we don't create a full literal <->
|
||||
// clause graph twice. It might make sense to reach the BCE fix point which
|
||||
// is unique before each variable elimination.
|
||||
@@ -147,13 +155,6 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) {
|
||||
blocked_clause_simplifier_->DoOneRound(log_round_info);
|
||||
}
|
||||
|
||||
// TODO(user): Think about the right order in this function.
|
||||
if (params_.inprocessing_use_congruence_closure()) {
|
||||
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
|
||||
RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables());
|
||||
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_round_info));
|
||||
}
|
||||
|
||||
// TODO(user): this break some binary graph invariant. Fix!
|
||||
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
|
||||
RETURN_IF_FALSE(bounded_variable_elimination_->DoOneRound(log_round_info));
|
||||
@@ -298,7 +299,7 @@ bool Inprocessing::InprocessingRound() {
|
||||
if (params_.inprocessing_use_congruence_closure()) {
|
||||
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
|
||||
RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables());
|
||||
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_round_info));
|
||||
RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_info));
|
||||
}
|
||||
|
||||
RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info));
|
||||
@@ -1956,6 +1957,7 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
|
||||
PresolveTimer& timer) {
|
||||
ids3_.clear();
|
||||
ids4_.clear();
|
||||
ids5_.clear();
|
||||
truth_tables_inputs_.clear();
|
||||
truth_tables_bitset_.clear();
|
||||
truth_tables_clauses_.clear();
|
||||
@@ -1971,6 +1973,8 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
|
||||
AddToTruthTable<3>(clause, ids3_);
|
||||
} else if (clause->size() == 4) {
|
||||
AddToTruthTable<4>(clause, ids4_);
|
||||
} else if (clause->size() == 5) {
|
||||
AddToTruthTable<5>(clause, ids5_);
|
||||
}
|
||||
|
||||
// Used for an optimization below.
|
||||
@@ -2103,68 +2107,235 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
|
||||
timer.AddCounter("and_gates", gates_inputs_.size());
|
||||
}
|
||||
|
||||
int GateCongruenceClosure::ProcessTruthTable(
|
||||
absl::Span<const BooleanVariable> inputs, SmallBitset truth_table,
|
||||
absl::Span<const TruthTableId> ids_for_proof) {
|
||||
int num_detected = 0;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (!IsFunction(i, inputs.size(), truth_table)) continue;
|
||||
const int num_bits = inputs.size() - 1;
|
||||
++num_detected;
|
||||
|
||||
gates_target_.push_back(Literal(inputs[i], true));
|
||||
gates_inputs_.Add({});
|
||||
for (int j = 0; j < inputs.size(); ++j) {
|
||||
if (i != j) {
|
||||
gates_inputs_.AppendToLastVector(Literal(inputs[j], true));
|
||||
}
|
||||
}
|
||||
|
||||
// Generate the function truth table as a type.
|
||||
// We will canonicalize it further in the main loop.
|
||||
unsigned int type = 0;
|
||||
for (int p = 0; p < (1 << num_bits); ++p) {
|
||||
// Expand from (num_bits == inputs.size() - 1) bits to (inputs.size())
|
||||
// bits.
|
||||
const int bigger_p = AddHoleAtPosition(i, p);
|
||||
|
||||
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
|
||||
// target is 1 at this position.
|
||||
type |= 1 << p;
|
||||
DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function.
|
||||
} else {
|
||||
// Note that if there is no feasible assignment for a given p if
|
||||
// [(truth_table >> bigger_p) & 1] is not 1.
|
||||
//
|
||||
// Here we could learn smaller clause, but also we don't really care
|
||||
// what is the value of the target at that point p, so we use zero.
|
||||
//
|
||||
// TODO(user): This is not ideal if another function has a value of 1 at
|
||||
// point p, we could still merge it with this one. Shall we create two
|
||||
// gate type? or change the algo?
|
||||
}
|
||||
}
|
||||
|
||||
gates_type_.push_back(type);
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
gates_clauses_.Add({});
|
||||
for (const TruthTableId id : ids_for_proof) {
|
||||
gates_clauses_.AppendToLastVector(truth_tables_clauses_[id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_detected;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Return a BooleanVariable from b that is not in a or kNoBooleanVariable;
|
||||
BooleanVariable FindMissing(absl::Span<const BooleanVariable> vars_a,
|
||||
absl::Span<const BooleanVariable> vars_b) {
|
||||
for (const BooleanVariable b : vars_b) {
|
||||
bool found = false;
|
||||
for (const BooleanVariable a : vars_a) {
|
||||
if (a == b) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!found) return b;
|
||||
}
|
||||
return kNoBooleanVariable;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// TODO(user): It should be possible to extract ALL possible short gate, but
|
||||
// we are not there yet.
|
||||
void GateCongruenceClosure::ExtractShortGates(PresolveTimer& timer) {
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
truth_tables_clauses_.ResetFromFlatMapping(tmp_ids_, tmp_clauses_);
|
||||
truth_tables_clauses_.ResetFromFlatMapping(
|
||||
tmp_ids_, tmp_clauses_,
|
||||
/*minimum_num_nodes=*/truth_tables_bitset_.size());
|
||||
CHECK_EQ(truth_tables_bitset_.size(), truth_tables_clauses_.size());
|
||||
}
|
||||
|
||||
// This is used to combine two 3 arity table into one 4 arity one if
|
||||
// they share two variables.
|
||||
absl::flat_hash_map<std::array<BooleanVariable, 2>, int> binary_index_map;
|
||||
std::vector<int> flat_binary_index;
|
||||
std::vector<TruthTableId> flat_table_id;
|
||||
|
||||
// Counters.
|
||||
// We only fill a subset of the entries, but that makes the code shorter.
|
||||
std::vector<int> num_tables(5);
|
||||
std::vector<int> num_functions(5);
|
||||
// We only fill a subset of the entries, but that make the code shorter.
|
||||
std::vector<int> num_tables(6);
|
||||
std::vector<int> num_functions(6);
|
||||
|
||||
// Note that using the indirection via TruthTableId allow this code to
|
||||
// be deterministic.
|
||||
CHECK_EQ(truth_tables_bitset_.size(), truth_tables_inputs_.size());
|
||||
for (TruthTableId id(0); id < truth_tables_inputs_.size(); ++id) {
|
||||
const absl::Span<const BooleanVariable> inputs = truth_tables_inputs_[id];
|
||||
const SmallBitset truth_table = truth_tables_bitset_[id];
|
||||
++num_tables[inputs.size()];
|
||||
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (!IsFunction(i, inputs.size(), truth_table)) continue;
|
||||
|
||||
const int num_bits = inputs.size() - 1;
|
||||
++num_functions[num_bits];
|
||||
gates_target_.push_back(Literal(inputs[i], true));
|
||||
gates_inputs_.Add({});
|
||||
for (int j = 0; j < inputs.size(); ++j) {
|
||||
if (i != j) {
|
||||
gates_inputs_.AppendToLastVector(Literal(inputs[j], true));
|
||||
}
|
||||
// Given a table of arity 4, this merges all the information from the tables
|
||||
// of arity 3 included in it.
|
||||
int num_merges = 0;
|
||||
const auto merge3_into_4 = [this, &num_merges](
|
||||
absl::Span<const BooleanVariable> inputs,
|
||||
SmallBitset& truth_table,
|
||||
std::vector<TruthTableId>& ids_for_proof) {
|
||||
DCHECK_EQ(inputs.size(), 4);
|
||||
for (int i_to_remove = 0; i_to_remove < inputs.size(); ++i_to_remove) {
|
||||
int pos = 0;
|
||||
std::array<BooleanVariable, 3> key3;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (i == i_to_remove) continue;
|
||||
key3[pos++] = inputs[i];
|
||||
}
|
||||
const auto it = ids3_.find(key3);
|
||||
if (it == ids3_.end()) continue;
|
||||
++num_merges;
|
||||
|
||||
// Generate the function truth table as a type.
|
||||
// We will canonicalize it further in the main loop.
|
||||
unsigned int type = 0;
|
||||
for (int p = 0; p < (1 << num_bits); ++p) {
|
||||
// Expand from (num_bits == inputs.size() - 1) bits to (inputs.size())
|
||||
// bits.
|
||||
const int bigger_p = AddHoleAtPosition(i, p);
|
||||
// Extend the bitset from the table3 so that it is expressed correctly
|
||||
// for the given inputs.
|
||||
const TruthTableId id3 = it->second;
|
||||
std::array<BooleanVariable, 4> key4;
|
||||
for (int i = 0; i < 3; ++i) key4[i] = key3[i];
|
||||
key4[3] = FindMissing(key3, inputs);
|
||||
SmallBitset bitset = truth_tables_bitset_[id3];
|
||||
bitset |= bitset << (1 << 3); // Extend for a new variable.
|
||||
CanonicalizeTruthTable<BooleanVariable>(absl::MakeSpan(key4), bitset);
|
||||
CHECK_EQ(inputs, absl::MakeSpan(key4));
|
||||
truth_table &= bitset;
|
||||
|
||||
if ((truth_table >> (bigger_p + (1 << i))) & 1) {
|
||||
// target is 1 at this position.
|
||||
type |= 1 << p;
|
||||
DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function.
|
||||
} else {
|
||||
// Note that if there is no feasible assignment for a given p, first
|
||||
// we could have learned a smaller clause, but also we don't really
|
||||
// care what is the value of the target at that point p, so we use
|
||||
// zero.
|
||||
}
|
||||
}
|
||||
|
||||
gates_type_.push_back(type);
|
||||
// We need to add the corresponding clause!
|
||||
if (lrat_proof_handler_ != nullptr) {
|
||||
gates_clauses_.Add(truth_tables_clauses_[id]);
|
||||
ids_for_proof.push_back(id3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Starts by processing all existing tables.
|
||||
//
|
||||
// TODO(user): Since we deal with and_gates differently, do we need to look at
|
||||
// binary clauses here ? or that is not needed ? I think there is only two
|
||||
// kind of Boolean function on two inputs (and_gates, with any possible
|
||||
// negation) and xor_gate.
|
||||
std::vector<TruthTableId> ids_for_proof;
|
||||
for (TruthTableId t_id(0); t_id < truth_tables_inputs_.size(); ++t_id) {
|
||||
ids_for_proof.clear();
|
||||
ids_for_proof.push_back(t_id);
|
||||
const absl::Span<const BooleanVariable> inputs = truth_tables_inputs_[t_id];
|
||||
SmallBitset truth_table = truth_tables_bitset_[t_id];
|
||||
|
||||
// Merge any size-3 table included inside a size-4 gate.
|
||||
// TODO(user): do it for larger gate too ?
|
||||
if (inputs.size() == 4) {
|
||||
merge3_into_4(inputs, truth_table, ids_for_proof);
|
||||
}
|
||||
|
||||
++num_tables[inputs.size()];
|
||||
const int num_detected =
|
||||
ProcessTruthTable(inputs, truth_table, ids_for_proof);
|
||||
num_functions[inputs.size() - 1] += num_detected;
|
||||
|
||||
// If this is not a function and of size 3, lets try to combine it with
|
||||
// another truth table of size 3 to get a table of size 4.
|
||||
if (inputs.size() == 3 && num_detected == 0) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
std::array<BooleanVariable, 2> key{inputs[i != 0 ? 0 : 1],
|
||||
inputs[i != 2 ? 2 : 1]};
|
||||
DCHECK(std::is_sorted(key.begin(), key.end()));
|
||||
const auto [it, inserted] =
|
||||
binary_index_map.insert({key, binary_index_map.size()});
|
||||
flat_binary_index.push_back(it->second);
|
||||
flat_table_id.push_back(t_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
gtl::STLClearObject(&binary_index_map);
|
||||
|
||||
// Detects ITE gates and potentially other 3-gates from a truth table of
|
||||
// 4-entries formed by two 3-entries table. This just create a 4-entries
|
||||
// table that will be processed below.
|
||||
CompactVectorVector<int, TruthTableId> candidates;
|
||||
candidates.ResetFromFlatMapping(flat_binary_index, flat_table_id);
|
||||
gtl::STLClearObject(&flat_binary_index);
|
||||
gtl::STLClearObject(&flat_table_id);
|
||||
int num_combinations = 0;
|
||||
for (int c = 0; c < candidates.size(); ++c) {
|
||||
if (candidates[c].size() < 2) continue;
|
||||
if (candidates[c].size() > 10) continue; // Too many? use heuristic.
|
||||
|
||||
for (int a = 0; a < candidates[c].size(); ++a) {
|
||||
for (int b = a + 1; b < candidates[c].size(); ++b) {
|
||||
const absl::Span<const BooleanVariable> inputs_a =
|
||||
truth_tables_inputs_[candidates[c][a]];
|
||||
const absl::Span<const BooleanVariable> inputs_b =
|
||||
truth_tables_inputs_[candidates[c][b]];
|
||||
|
||||
std::array<BooleanVariable, 4> key;
|
||||
for (int i = 0; i < 3; ++i) key[i] = inputs_a[i];
|
||||
key[3] = FindMissing(inputs_a, inputs_b);
|
||||
CHECK_NE(key[3], kNoBooleanVariable);
|
||||
|
||||
// Add an all allowed entry.
|
||||
SmallBitset bitmask = GetNumBitsAtOne(4);
|
||||
CanonicalizeTruthTable<BooleanVariable>(absl::MakeSpan(key), bitmask);
|
||||
|
||||
// If the key was not processed before, process it now.
|
||||
// Note that an old version created a TruthTableId for it, but that
|
||||
// waste a lot of space.
|
||||
//
|
||||
// On another hand, it is possible we process the same key up to
|
||||
// 4_choose_2 times, but this is rare...
|
||||
if (!ids4_.contains(key)) {
|
||||
++num_combinations;
|
||||
++num_tables[4];
|
||||
ids_for_proof.clear();
|
||||
merge3_into_4(key, bitmask, ids_for_proof);
|
||||
num_functions[3] += ProcessTruthTable(key, bitmask, ids_for_proof);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
timer.AddCounter("combine3", num_combinations);
|
||||
timer.AddCounter("merges", num_merges);
|
||||
|
||||
// Note that we only display non-zero counters.
|
||||
for (int i = 2; i < 5; ++i) {
|
||||
timer.AddCounter(absl::StrCat("table", i), num_tables[i]);
|
||||
for (int i = 2; i < num_tables.size(); ++i) {
|
||||
timer.AddCounter(absl::StrCat("t", i), num_tables[i]);
|
||||
}
|
||||
for (int i = 2; i < num_functions.size(); ++i) {
|
||||
timer.AddCounter(absl::StrCat("fn", i), num_functions[i]);
|
||||
}
|
||||
}
|
||||
@@ -2583,8 +2754,9 @@ class LratGateCongruenceHelper {
|
||||
} // namespace
|
||||
|
||||
bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
// TODO(user): Remove this condition, it is possible there are no binary
|
||||
// and still gates!
|
||||
if (implication_graph_->IsEmpty()) return true;
|
||||
clause_manager_->DetachAllClauses();
|
||||
|
||||
PresolveTimer timer("GateCongruenceClosure", logger_, time_limit_);
|
||||
timer.OverrideLogging(log_info);
|
||||
@@ -2725,9 +2897,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
const LiteralIndex y =
|
||||
negate ? rep_b.NegatedIndex() : rep_b.Index();
|
||||
|
||||
// TODO(user): We need to change the union_find algo used to be sure
|
||||
// the that if rep(x) = y then rep(not(x)) = not(y), otherwise we
|
||||
// might miss some reductions.
|
||||
// Because x always refer to a and y to b, this should maintain
|
||||
// the invariant root(lit) = root(lit.Negated()).Negated().
|
||||
// This is checked below.
|
||||
union_find.AddEdge(x.value(), y.value());
|
||||
const LiteralIndex rep(union_find.FindRoot(y.value()));
|
||||
const LiteralIndex to_merge = rep == x ? y : x;
|
||||
@@ -2758,6 +2930,12 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
in_queue[gate_id.value()] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Invariant.
|
||||
CHECK_EQ(
|
||||
lrat_helper.GetRepresentativeWithProofSupport(rep_a),
|
||||
lrat_helper.GetRepresentativeWithProofSupport(rep_a.Negated())
|
||||
.Negated());
|
||||
}
|
||||
break;
|
||||
}
|
||||
@@ -2840,8 +3018,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
// Generic "short" gates.
|
||||
// We just take the representative and re-canonicalize.
|
||||
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
|
||||
CHECK_GE(gates_type_[id], 0);
|
||||
CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
|
||||
DCHECK_GE(gates_type_[id], 0);
|
||||
DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
|
||||
for (LiteralIndex& lit_ref : inputs) {
|
||||
lit_ref =
|
||||
lrat_helper.GetRepresentativeWithProofSupport(Literal(lit_ref))
|
||||
@@ -2853,6 +3031,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
if (new_size < inputs.size()) {
|
||||
gates_inputs_.Shrink(id, new_size);
|
||||
}
|
||||
DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
|
||||
|
||||
if (new_size == 1) {
|
||||
// We have a function of size 1! This is an equivalence.
|
||||
@@ -2863,7 +3042,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
|
||||
break;
|
||||
} else if (new_size == 0) {
|
||||
// We have a fixed function! Just fix the literal.
|
||||
CHECK(Literal(gates_target_[id]).IsPositive());
|
||||
const Literal initial_to_fix =
|
||||
(gates_type_[id] & 1) == 1 ? Literal(gates_target_[id])
|
||||
: Literal(gates_target_[id]).Negated();
|
||||
|
||||
@@ -500,6 +500,12 @@ class GateCongruenceClosure {
|
||||
// functions of the form one_var = f(other_vars).
|
||||
void ExtractShortGates(PresolveTimer& timer);
|
||||
|
||||
// Detects gates encoded in the given truth table, and add them to the set
|
||||
// of gates. Returns the number of gate detected.
|
||||
int ProcessTruthTable(absl::Span<const BooleanVariable> inputs,
|
||||
SmallBitset truth_table,
|
||||
absl::Span<const TruthTableId> ids_for_proof = {});
|
||||
|
||||
// Add a small clause to the corresponding truth table.
|
||||
template <int arity>
|
||||
void AddToTruthTable(SatClause* clause,
|
||||
@@ -550,6 +556,7 @@ class GateCongruenceClosure {
|
||||
// truth_tables_inputs_, this is a bit wasted but simplify the code.
|
||||
absl::flat_hash_map<std::array<BooleanVariable, 3>, TruthTableId> ids3_;
|
||||
absl::flat_hash_map<std::array<BooleanVariable, 4>, TruthTableId> ids4_;
|
||||
absl::flat_hash_map<std::array<BooleanVariable, 5>, TruthTableId> ids5_;
|
||||
CompactVectorVector<TruthTableId, BooleanVariable> truth_tables_inputs_;
|
||||
util_intops::StrongVector<TruthTableId, SmallBitset> truth_tables_bitset_;
|
||||
CompactVectorVector<TruthTableId, SatClause*> truth_tables_clauses_;
|
||||
|
||||
@@ -134,6 +134,7 @@ class CompactVectorVector {
|
||||
// Returns the previous size() as this is convenient for how we use it.
|
||||
int Add(absl::Span<const V> values);
|
||||
void AppendToLastVector(const V& value);
|
||||
void AppendToLastVector(absl::Span<const V> values);
|
||||
|
||||
// Hacky: same as Add() but for sat::Literal or any type from which we can get
|
||||
// a value type V via L.Index().value().
|
||||
@@ -919,6 +920,13 @@ inline void CompactVectorVector<K, V>::AppendToLastVector(const V& value) {
|
||||
buffer_.push_back(value);
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
inline void CompactVectorVector<K, V>::AppendToLastVector(
|
||||
absl::Span<const V> values) {
|
||||
sizes_.back() += values.size();
|
||||
buffer_.insert(buffer_.end(), values.begin(), values.end());
|
||||
}
|
||||
|
||||
template <typename K, typename V>
|
||||
inline void CompactVectorVector<K, V>::ReplaceValuesBySmallerSet(
|
||||
K key, absl::Span<const V> values) {
|
||||
|
||||
Reference in New Issue
Block a user