[CP-SAT] more work on lrat, regroup linear1 presolve methods

This commit is contained in:
Laurent Perron
2025-12-27 11:58:47 +01:00
committed by Corentin Le Molgat
parent b28edf1d04
commit c8d7710fd7
15 changed files with 521 additions and 258 deletions

View File

@@ -2017,11 +2017,18 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
// been cleaned up yet, as these are needed to really recover all gates.
//
// TODO(user): Ideally the detection code should be robust to that.
// TODO(user): Maybe we should always have an hash-map of binary up to date?
int num_fn1 = 0;
std::vector<std::pair<Literal, Literal>> binary_used;
for (LiteralIndex a(0); a < implication_graph_->literal_size(); ++a) {
// TODO(user): If we know we have too many implications for the time limit
// We should just be better of not doing that loop at all.
if (timer.WorkLimitIsReached()) break;
if (implication_graph_->IsRedundant(Literal(a))) continue;
for (const Literal b : implication_graph_->Implications(Literal(a))) {
const absl::Span<const Literal> implied =
implication_graph_->Implications(Literal(a));
timer.TrackHashLookups(implied.size());
for (const Literal b : implied) {
if (implication_graph_->IsRedundant(b)) continue;
std::array<BooleanVariable, 2> key2;
@@ -2066,9 +2073,7 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables(
// The AND gate of size 3 should be detected by the short table code, no
// need to do the algo here which should be slower.
//
// TODO(user): This seems to be less strong. I think we have some bug
// in our fixed point loop when we fix variables.
continue;
} else if (clause->size() == 4) {
AddToTruthTable<4>(clause, ids4_);
} else if (clause->size() == 5) {
@@ -2867,6 +2872,7 @@ class LratGateCongruenceHelper {
implication_graph_->GetClauseId(target.Negated(), Literal(m_index)));
Append(clause_ids,
GetLiteralImpliesRepresentativeClauseId(Literal(m_index)));
Append(clause_ids, GetLiteralImpliesRepresentativeClauseId(target));
}
private:
@@ -2943,7 +2949,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
PresolveTimer timer("GateCongruenceClosure", logger_, time_limit_);
timer.OverrideLogging(log_info);
const int num_literals(sat_solver_->NumVariables() * 2);
const int num_variables(sat_solver_->NumVariables());
const int num_literals(num_variables * 2);
marked_.ClearAndResize(Literal(num_literals));
seen_.ClearAndResize(Literal(num_literals));
next_seen_.ClearAndResize(Literal(num_literals));
@@ -2955,7 +2962,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// Lets release the memory on exit.
CHECK(tmp_binary_clauses_.empty());
absl::Cleanup cleanup = [this] { tmp_binary_clauses_.clear(); };
absl::Cleanup binary_cleanup = [this] { tmp_binary_clauses_.clear(); };
ExtractAndGatesAndFillShortTruthTables(timer);
ExtractShortGates(timer);
@@ -2985,37 +2992,67 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// Tricky: we need to resize this to num_literals because the union_find that
// merges target can choose for a representative a literal that is not in the
// set of gate inputs.
MergeableOccurrenceList<LiteralIndex, GateId> input_literals_to_gate;
input_literals_to_gate.ResetFromTranspose(gates_inputs_, num_literals);
MergeableOccurrenceList<BooleanVariable, GateId> input_var_to_gate;
struct GetVarMapper {
BooleanVariable operator()(LiteralIndex l) const {
return Literal(l).Variable();
}
};
input_var_to_gate.ResetFromTransposeMap<GetVarMapper>(gates_inputs_,
num_variables);
LratGateCongruenceHelper lrat_helper(
trail_, implication_graph_, clause_manager_, clause_id_generator_,
lrat_proof_handler_, gates_target_, gates_clauses_, union_find);
// Stats + make sure we run it at exit.
int num_units = 0;
int num_equivalences = 0;
int num_processed = 0;
int arity1_equivalences = 0;
absl::Cleanup stat_cleanup = [&] {
total_wtime_ += timer.wtime();
total_dtime_ += timer.deterministic_time();
total_equivalences_ += num_equivalences;
total_num_units_ += num_units;
timer.AddCounter("processed", num_processed);
timer.AddCounter("units", num_units);
timer.AddCounter("f1_equiv", arity1_equivalences);
timer.AddCounter("equiv", num_equivalences);
};
// Starts with all gates in the queue.
const int num_gates = gates_inputs_.size();
total_gates_ += num_gates;
std::vector<bool> in_queue(num_gates, true);
std::vector<GateId> queue(num_gates);
for (GateId id(0); id < num_gates; ++id) queue[id.value()] = id;
int num_units = 0;
int num_processed_fixed_variables = trail_->Index();
const auto fix_literal = [&, this](Literal to_fix,
absl::Span<const ClauseId> clause_ids) {
DCHECK_EQ(to_fix, lrat_helper.GetRepresentativeWithProofSupport(to_fix));
if (assignment_.LiteralIsTrue(to_fix)) return true;
if (!clause_manager_->InprocessingFixLiteral(to_fix, clause_ids)) {
return false;
}
// This is quite tricky: as we fix a literal, we propagate right away
// everything implied by it in the binary implication graph. So we need to
// loop over all newly_fixed variable in order to properly reach the fix
// point!
++num_units;
for (const GateId gate_id : input_literals_to_gate[to_fix]) {
if (in_queue[gate_id.value()]) continue;
queue.push_back(gate_id);
in_queue[gate_id.value()] = true;
}
for (const GateId gate_id : input_literals_to_gate[to_fix.Negated()]) {
if (in_queue[gate_id.value()]) continue;
queue.push_back(gate_id);
in_queue[gate_id.value()] = true;
for (; num_processed_fixed_variables < trail_->Index();
++num_processed_fixed_variables) {
const Literal to_update = lrat_helper.GetRepresentativeWithProofSupport(
(*trail_)[num_processed_fixed_variables]);
for (const GateId gate_id : input_var_to_gate[to_update.Variable()]) {
if (in_queue[gate_id.value()]) continue;
queue.push_back(gate_id);
in_queue[gate_id.value()] = true;
}
input_var_to_gate.ClearList(to_update.Variable());
}
return true;
};
@@ -3025,7 +3062,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
return trail_->GetUnitClauseId(a.Variable());
};
int num_equivalences = 0;
const auto new_equivalence = [&, this](Literal a, Literal b,
ClauseId a_implies_b,
ClauseId b_implies_a) {
@@ -3052,6 +3088,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
return false;
}
BooleanVariable to_merge_var = kNoBooleanVariable;
BooleanVariable rep_var = kNoBooleanVariable;
for (const bool negate : {false, true}) {
const LiteralIndex x = negate ? a.NegatedIndex() : a.Index();
const LiteralIndex y = negate ? b.NegatedIndex() : b.Index();
@@ -3064,7 +3102,14 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
union_find.AddEdge(x.value(), y.value());
const LiteralIndex rep(union_find.FindRoot(y.value()));
const LiteralIndex to_merge = rep == x ? y : x;
input_literals_to_gate.MergeInto(to_merge, rep);
if (to_merge_var == kNoBooleanVariable) {
to_merge_var = Literal(to_merge).Variable();
rep_var = Literal(rep).Variable();
} else {
// We should have the same var.
CHECK_EQ(to_merge_var, Literal(to_merge).Variable());
CHECK_EQ(rep_var, Literal(rep).Variable());
}
if (lrat_proof_handler_ != nullptr) {
if (rep == x) {
@@ -3075,17 +3120,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
y_implies_x);
}
}
// Re-add to the queue all gates with touched inputs.
//
// TODO(user): I think we could only add the gates of "to_merge"
// before we merge. This part of the code is quite quick in any
// case.
for (const GateId gate_id : input_literals_to_gate[rep]) {
if (in_queue[gate_id.value()]) continue;
queue.push_back(gate_id);
in_queue[gate_id.value()] = true;
}
}
// Invariant.
@@ -3095,16 +3129,28 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
CHECK_EQ(
lrat_helper.GetRepresentativeWithProofSupport(b),
lrat_helper.GetRepresentativeWithProofSupport(b.Negated()).Negated());
// Re-add to the queue all gates with touched inputs.
//
// TODO(user): I think we could only add the gates of "to_merge"
// before we merge. This part of the code is quite quick in any
// case.
input_var_to_gate.MergeInto(to_merge_var, rep_var);
for (const GateId gate_id : input_var_to_gate[rep_var]) {
if (in_queue[gate_id.value()]) continue;
queue.push_back(gate_id);
in_queue[gate_id.value()] = true;
}
return true;
};
// Main loop.
int num_processed = 0;
int arity1_equivalences = 0;
while (!queue.empty()) {
++num_processed;
const GateId id = queue.back();
queue.pop_back();
CHECK(in_queue[id.value()]);
in_queue[id.value()] = false;
// Tricky: the hash-map might contain id not yet canonicalized. And in
@@ -3140,17 +3186,15 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
CHECK_NE(id, other_id);
CHECK_GE(other_id, 0);
CHECK_EQ(gates_type_[id], gates_type_[other_id]);
CHECK_EQ(absl::Span<const LiteralIndex>(gates_inputs_[id]),
absl::Span<const LiteralIndex>(gates_inputs_[other_id]));
CHECK_EQ(gates_inputs_[id], gates_inputs_[other_id]);
input_literals_to_gate.RemoveFromFutureOutput(id);
input_var_to_gate.RemoveFromFutureOutput(id);
// We detected a <=> b (or, equivalently, rep(a) <=> rep(b)).
const Literal a(gates_target_[id]);
const Literal b(gates_target_[other_id]);
const Literal rep_a = lrat_helper.GetRepresentativeWithProofSupport(a);
const Literal rep_b = lrat_helper.GetRepresentativeWithProofSupport(b);
if (rep_a != rep_b) {
ClauseId rep_a_implies_rep_b = kNoClauseId;
ClauseId rep_b_implies_rep_a = kNoClauseId;
@@ -3200,9 +3244,11 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// then target must be false.
if (marked_[Literal(rep).Negated()]) {
is_unit = true;
input_literals_to_gate.RemoveFromFutureOutput(id);
input_var_to_gate.RemoveFromFutureOutput(id);
const Literal to_fix = Literal(gates_target_[id]).Negated();
const Literal initial_to_fix = Literal(gates_target_[id]).Negated();
const Literal to_fix =
lrat_helper.GetRepresentativeWithProofSupport(initial_to_fix);
if (!assignment_.LiteralIsTrue(to_fix)) {
absl::InlinedVector<ClauseId, 4> clause_ids;
if (lrat_proof_handler_ != nullptr) {
@@ -3249,10 +3295,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
// Generic "short" gates.
// We just take the representative and re-canonicalize.
absl::Span<LiteralIndex> inputs = gates_inputs_[id];
DCHECK_GE(gates_type_[id], 0);
DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0);
for (LiteralIndex& lit_ref : inputs) {
DCHECK_EQ(gates_type_[id] >> (1 << (gates_inputs_[id].size())), 0);
for (LiteralIndex& lit_ref : gates_inputs_[id]) {
lit_ref =
lrat_helper.GetRepresentativeWithProofSupport(Literal(lit_ref))
.Index();
@@ -3261,7 +3306,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
const int new_size = CanonicalizeShortGate(id);
if (new_size == 1) {
// We have a function of size 1! This is an equivalence.
input_literals_to_gate.RemoveFromFutureOutput(id);
input_var_to_gate.RemoveFromFutureOutput(id);
const Literal a = Literal(gates_target_[id]);
const Literal b = Literal(gates_inputs_[id][0]);
const Literal rep_a = lrat_helper.GetRepresentativeWithProofSupport(a);
@@ -3277,7 +3322,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
break;
} else if (new_size == 0) {
// We have a fixed function! Just fix the literal.
input_literals_to_gate.RemoveFromFutureOutput(id);
input_var_to_gate.RemoveFromFutureOutput(id);
const Literal initial_to_fix =
(gates_type_[id] & 1) == 1 ? Literal(gates_target_[id])
: Literal(gates_target_[id]).Negated();
@@ -3293,16 +3338,44 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) {
}
}
total_wtime_ += timer.wtime();
total_dtime_ += timer.deterministic_time();
total_gates_ += num_gates;
total_equivalences_ += num_equivalences;
total_num_units_ += num_units;
// DEBUG check that we reached the fix point correctly.
if (DEBUG_MODE) {
CHECK(queue.empty());
gate_set.clear();
for (GateId id(0); id < num_gates; ++id) {
if (gates_type_[id] == kAndGateType) continue;
if (assignment_.LiteralIsAssigned(Literal(gates_target_[id]))) continue;
const int new_size = CanonicalizeShortGate(id);
if (new_size == 0) {
CHECK_EQ(gates_type_[id] & 1, 0);
const Literal initial_to_fix = Literal(gates_target_[id]).Negated();
const Literal to_fix =
lrat_helper.GetRepresentativeWithProofSupport(initial_to_fix);
CHECK(assignment_.LiteralIsTrue(to_fix));
} else if (new_size == 1) {
CHECK(!assignment_.LiteralIsAssigned(Literal(gates_target_[id])));
CHECK(!assignment_.LiteralIsAssigned(Literal(gates_inputs_[id][0])));
CHECK_EQ(lrat_helper.GetRepresentativeWithProofSupport(
Literal(gates_target_[id])),
lrat_helper.GetRepresentativeWithProofSupport(
Literal(gates_inputs_[id][0])))
<< id << " ";
} else {
const auto [it, inserted] = gate_set.insert(id);
if (!inserted) {
const GateId other_id = *it;
CHECK_EQ(lrat_helper.GetRepresentativeWithProofSupport(
Literal(gates_target_[id])),
lrat_helper.GetRepresentativeWithProofSupport(
Literal(gates_target_[other_id])))
<< id << " " << gates_inputs_[id] << " " << other_id << " "
<< gates_inputs_[other_id];
}
}
}
}
timer.AddCounter("arity1_equivalences", arity1_equivalences);
timer.AddCounter("units", num_units);
timer.AddCounter("processed", num_processed);
timer.AddCounter("equivalences", num_equivalences);
return true;
}