diff --git a/ortools/sat/gate_utils.h b/ortools/sat/gate_utils.h index 99301bf8a1..d902c57977 100644 --- a/ortools/sat/gate_utils.h +++ b/ortools/sat/gate_utils.h @@ -26,6 +26,11 @@ namespace operations_research::sat { using SmallBitset = uint32_t; +// This works for num_bits == 32 too. +inline SmallBitset GetNumBitsAtOne(int num_bits) { + return ~SmallBitset(0) >> (32 - (1 << num_bits)); +} + // Sort the key and modify the truth table accordingly. // // Note that we don't deal with identical key here, but the function @@ -35,7 +40,6 @@ template void CanonicalizeTruthTable(absl::Span key, SmallBitset& bitmask) { const int num_bits = key.size(); - CHECK_EQ(bitmask >> (1 << num_bits), 0); for (int i = 0; i < num_bits; ++i) { for (int j = i + 1; j < num_bits; ++j) { if (key[i] <= key[j]) continue; @@ -53,11 +57,9 @@ void CanonicalizeTruthTable(absl::Span key, new_bitmask |= ((bitmask >> p) & 1) << new_p; } bitmask = new_bitmask; - CHECK_EQ(bitmask >> (1 << num_bits), 0) - << i << " " << j << " " << num_bits; } } - CHECK(std::is_sorted(key.begin(), key.end())); + DCHECK(std::is_sorted(key.begin(), key.end())); } // Given a clause, return the truth table corresponding to it. @@ -67,15 +69,14 @@ inline void FillKeyAndBitmask(absl::Span clause, SmallBitset& bitmask) { CHECK_EQ(clause.size(), key.size()); const int num_bits = clause.size(); - bitmask = ~SmallBitset(0) >> (32 - (1 << num_bits)); // All allowed; - CHECK_EQ(bitmask >> (1 << num_bits), 0) << num_bits; SmallBitset bit_to_remove = 0; for (int i = 0; i < num_bits; ++i) { key[i] = clause[i].Variable(); bit_to_remove |= (clause[i].IsPositive() ? 0 : 1) << i; } - bitmask ^= (SmallBitset(1) << bit_to_remove); - CHECK_EQ(bitmask >> (1 << num_bits), 0) << bit_to_remove << " " << num_bits; + CHECK_LT(bit_to_remove, (1 << num_bits)); + bitmask = GetNumBitsAtOne(num_bits); + bitmask ^= SmallBitset(1) << bit_to_remove; CanonicalizeTruthTable(key, bitmask); } @@ -111,15 +112,6 @@ inline int CanonicalizeFunctionTruthTable(LiteralIndex& target, const int num_bits = inputs.size(); CHECK_LE(num_bits, 4); // Truth table must fit on an int. - const SmallBitset all_one = (1 << (1 << num_bits)) - 1; - CHECK_EQ(function_values & ~all_one, 0); - - // Make sure target is positive. - if (!Literal(target).IsPositive()) { - target = Literal(target).Negated(); - function_values = function_values ^ all_one; - CHECK_EQ(function_values >> (1 << num_bits), 0); - } // Make sure all inputs are positive. for (int i = 0; i < num_bits; ++i) { @@ -189,6 +181,17 @@ inline int CanonicalizeFunctionTruthTable(LiteralIndex& target, } } + // If we have x = f(a,b,c) and not(y) = f(a,b,c) with the same f, we have an + // equivalence, so we need to canonicallize both f() and not(f()) to the same + // function. For that we just always choose to have the lowest bit at zero. + if (function_values & 1) { + target = Literal(target).Negated(); + const SmallBitset all_one = GetNumBitsAtOne(inputs.size()); + function_values = function_values ^ all_one; + DCHECK_EQ(function_values >> (1 << inputs.size()), 0); + } + DCHECK_EQ(function_values & 1, 0); + int_function_values = function_values; return inputs.size(); } diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 37f46ac6c3..573324e538 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -1667,10 +1667,6 @@ ReasonIndex IntegerTrail::AppendReasonToInternalBuffers( return reason_index; } -int64_t IntegerTrail::NextConflictId() { - return sat_solver_->num_failures() + 1; -} - bool IntegerTrail::EnqueueInternal( IntegerLiteral i_lit, bool use_lazy_reason, absl::Span literal_reason, @@ -1721,7 +1717,8 @@ bool IntegerTrail::EnqueueInternal( return false; } - MergeReasonIntoInternal(conflict, NextConflictId()); + // TODO(user): fix or remove the conflict_id optimization. + MergeReasonIntoInternal(conflict, -1); return false; } @@ -1793,7 +1790,8 @@ bool IntegerTrail::EnqueueInternal( return false; } - MergeReasonIntoInternal(conflict, NextConflictId()); + // TODO(user): fix or remove the conflict_id optimization. + MergeReasonIntoInternal(conflict, -1); return false; } @@ -2348,7 +2346,10 @@ absl::Span IntegerTrail::Reason( tmp_queue_.push_back(prev_trail_index); } + // TODO(user): fix or remove the conflict_id optimization. + // It is probably superseded by the new integer resolution in any case, + // so I fill we could just remove it. MergeReasonIntoInternal(reason, -1); if (DEBUG_MODE && debug_checker_ != nullptr) { diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 96fdebf1e5..9ff965d6e0 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -1090,9 +1090,6 @@ class IntegerTrail final : public SatPropagator { // Returns some debugging info. std::string DebugString(); - // Used internally to return the next conflict number. - int64_t NextConflictId(); - // Information for each integer variable about its current lower bound and // position of the last TrailEntry in the trail referring to this var. util_intops::StrongVector var_lbs_; diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index 53c9c97c88..8a909e2272 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -122,7 +122,10 @@ void MinimizeCoreWithSearch(TimeLimit* limit, SatSolver* solver, const int old_size = core->size(); std::vector assumptions; absl::flat_hash_set removed_once; - while (true) { + + // We stop as soon as the core size is one since there is nothing more + // to minimize then. + while (core->size() > 1) { if (limit->LimitReached()) break; // Find a not yet removed literal to remove. diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index dfc01b3432..44eab008b8 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -184,6 +184,46 @@ _IndexOrSeries = Union[pd.Index, pd.Series] # Helper functions. + + +# warnings.deprecated is python3.13+. Not compatible with Open Source (3.10+). +# pylint: disable=g-bare-generic +def deprecated(message: str) -> Callable[[Callable], Callable]: + """Decorator that warns about a deprecated function.""" + + def deprecated_decorator(func) -> Callable: + def deprecated_func(*args, **kwargs): + warnings.warn( + f"{func.__name__} is a deprecated function. {message}", + category=DeprecationWarning, + stacklevel=2, + ) + warnings.simplefilter("default", DeprecationWarning) + return func(*args, **kwargs) + + return deprecated_func + + return deprecated_decorator + + +def deprecated_method(func, old_name: str) -> Callable: + """Wrapper that warns about a deprecated method.""" + + def deprecated_func(*args, **kwargs) -> Any: + warnings.warn( + f"{old_name} is a deprecated function. Use {func.__name__} instead.", + category=DeprecationWarning, + stacklevel=2, + ) + warnings.simplefilter("default", DeprecationWarning) + return func(*args, **kwargs) + + return deprecated_func + + +# pylint: enable=g-bare-generic + + def snake_case_to_camel_case(name: str) -> str: """Converts a snake_case name to CamelCase.""" words = name.split("_") @@ -274,16 +314,6 @@ class CpModel(cmh.CpBaseModel): cmh.CpBaseModel.__init__(self, model_proto) self._add_pre_pep8_methods() - def _add_pre_pep8_methods(self) -> None: - for method_name in dir(self): - if callable(getattr(self, method_name)) and ( - method_name.startswith("add_") - or method_name.startswith("new_") - or method_name.startswith("clear_") - ): - pre_pep8_name = snake_case_to_camel_case(method_name) - setattr(self, pre_pep8_name, getattr(self, method_name)) - # Naming. @property def name(self) -> str: @@ -1649,30 +1679,52 @@ class CpModel(cmh.CpBaseModel): # Compatibility with pre PEP8 # pylint: disable=invalid-name + def _add_pre_pep8_methods(self) -> None: + for method_name in dir(self): + if callable(getattr(self, method_name)) and ( + method_name.startswith("add_") + or method_name.startswith("new_") + or method_name.startswith("clear_") + ): + pre_pep8_name = snake_case_to_camel_case(method_name) + setattr( + self, + pre_pep8_name, + deprecated_method(getattr(self, method_name), pre_pep8_name), + ) + + for other_method_name in [ + "add", + "clone", + "get_bool_var_from_proto_index", + "get_int_var_from_proto_index", + "get_interval_var_from_proto_index", + "minimize", + "maximize", + "has_objective", + "model_stats", + "validate", + "export_to_file", + ]: + pre_pep8_name = snake_case_to_camel_case(other_method_name) + setattr( + self, + pre_pep8_name, + deprecated_method(getattr(self, other_method_name), pre_pep8_name), + ) + + @deprecated("Use name property instead.") def Name(self) -> str: return self.name + @deprecated("Use name property instead.") def SetName(self, name: str) -> None: self.name = name + @deprecated("Use proto property instead.") def Proto(self) -> cmh.CpModelProto: return self.proto - Add = add - Clone = clone - GetBoolVarFromProtoIndex = get_bool_var_from_proto_index - GetIntVarFromProtoIndex = get_int_var_from_proto_index - GetIntervalVarFromProtoIndex = get_interval_var_from_proto_index - Minimize = minimize - Maximize = maximize - HasObjective = has_objective - ModelStats = model_stats - Validate = validate - ExportToFile = export_to_file - - # add_XXX, new_XXX, and clear_XXX methods are already duplicated - # automatically. - # pylint: enable=invalid-name @@ -1924,49 +1976,85 @@ class CpSolver: # Compatibility with pre PEP8 # pylint: disable=invalid-name + @deprecated("Use best_objective_bound property instead.") def BestObjectiveBound(self) -> float: return self.best_objective_bound - BooleanValue = boolean_value - BooleanValues = boolean_values + @deprecated("Use boolean_value() method instead.") + def BooleanValue(self, lit: LiteralT) -> bool: + return self.boolean_value(lit) + @deprecated("Use boolean_values() method instead.") + def BooleanValues(self, variables: _IndexOrSeries) -> pd.Series: + return self.boolean_values(variables) + + @deprecated("Use num_booleans property instead.") def NumBooleans(self) -> int: return self.num_booleans + @deprecated("Use num_conflicts property instead.") def NumConflicts(self) -> int: return self.num_conflicts + @deprecated("Use num_branches property instead.") def NumBranches(self) -> int: return self.num_branches + @deprecated("Use objective_value property instead.") def ObjectiveValue(self) -> float: return self.objective_value + @deprecated("Use response_proto property instead.") def ResponseProto(self) -> cmh.CpSolverResponse: return self.response_proto - ResponseStats = response_stats - Solve = solve - SolutionInfo = solution_info - StatusName = status_name - StopSearch = stop_search - SufficientAssumptionsForInfeasibility = sufficient_assumptions_for_infeasibility + @deprecated("Use response_stats() method instead.") + def ResponseStats(self) -> str: + return self.response_stats() + @deprecated("Use solve() method instead.") + def Solve( + self, model: CpModel, callback: "CpSolverSolutionCallback" = None + ) -> cmh.CpSolverStatus: + return self.solve(model, callback) + + @deprecated("Use solution_info() method instead.") + def SolutionInfo(self) -> str: + return self.solution_info() + + @deprecated("Use status_name() method instead.") + def StatusName(self, status: Optional[Any] = None) -> str: + return self.status_name(status) + + @deprecated("Use stop_search() method instead.") + def StopSearch(self) -> None: + self.stop_search() + + @deprecated("Use sufficient_assumptions_for_infeasibility() method instead.") + def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]: + return self.sufficient_assumptions_for_infeasibility() + + @deprecated("Use user_time property instead.") def UserTime(self) -> float: return self.user_time - Value = value - Values = values + @deprecated("Use value() method instead.") + def Value(self, expression: LinearExprT) -> int: + return self.value(expression) + @deprecated("Use values() method instead.") + def Values(self, expressions: _IndexOrSeries) -> pd.Series: + return self.values(expressions) + + @deprecated("Use wall_time property instead.") def WallTime(self) -> float: return self.wall_time + @deprecated("Use solve() with enumerate_all_solutions = True.") def SearchForAllSolutions( self, model: CpModel, callback: "CpSolverSolutionCallback" ) -> cmh.CpSolverStatus: - """DEPRECATED Use solve() with the right parameter. - - Search for all solutions of a satisfiability problem. + """Search for all solutions of a satisfiability problem. This method searches for all feasible solutions of a given model. Then it feeds the solution to the callback. @@ -1984,11 +2072,6 @@ class CpSolver: * *INFEASIBLE* if the solver has proved there are no solution * *OPTIMAL* if all solutions have been found """ - warnings.warn( - "search_for_all_solutions is deprecated; use solve() with" - + "enumerate_all_solutions = True.", - DeprecationWarning, - ) if model.has_objective(): raise TypeError( "Search for all solutions is only defined on satisfiability problems" diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index 92505d347c..4d6ae57539 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -2767,10 +2767,13 @@ TRFM""" model = cp_model.CpModel() x = [model.NewBoolVar(f"x{i}") for i in range(5)] model.AddBoolOr(x) + model.Maximize(sum(x)) self.assertLen(model.proto.variables, 5) self.assertLen(model.proto.constraints, 1) self.assertLen(model.proto.constraints[0].bool_or.literals, 5) + self.assertTrue(hasattr(model, "Proto")) + model_copy = copy.copy(model) self.assertTrue(hasattr(model_copy, "AddBoolOr")) self.assertTrue(hasattr(model_copy, "AddBoolXOr")) diff --git a/ortools/sat/sat_inprocessing.cc b/ortools/sat/sat_inprocessing.cc index 925137301c..8b93a5ab6b 100644 --- a/ortools/sat/sat_inprocessing.cc +++ b/ortools/sat/sat_inprocessing.cc @@ -89,6 +89,7 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) { // Mainly useful for development. double probing_time = 0.0; + const bool log_info = VLOG_IS_ON(2); const bool log_round_info = VLOG_IS_ON(2); // We currently do the transformations in a given order and restart each time @@ -140,6 +141,13 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) { continue; } + // TODO(user): Think about the right order in this function. + if (params_.inprocessing_use_congruence_closure()) { + RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info)); + RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables()); + RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_info)); + } + // TODO(user): Combine the two? this way we don't create a full literal <-> // clause graph twice. It might make sense to reach the BCE fix point which // is unique before each variable elimination. @@ -147,13 +155,6 @@ bool Inprocessing::PresolveLoop(SatPresolveOptions options) { blocked_clause_simplifier_->DoOneRound(log_round_info); } - // TODO(user): Think about the right order in this function. - if (params_.inprocessing_use_congruence_closure()) { - RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info)); - RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables()); - RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_round_info)); - } - // TODO(user): this break some binary graph invariant. Fix! RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info)); RETURN_IF_FALSE(bounded_variable_elimination_->DoOneRound(log_round_info)); @@ -298,7 +299,7 @@ bool Inprocessing::InprocessingRound() { if (params_.inprocessing_use_congruence_closure()) { RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info)); RETURN_IF_FALSE(implication_graph_->RemoveDuplicatesAndFixedVariables()); - RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_round_info)); + RETURN_IF_FALSE(congruence_closure_->DoOneRound(log_info)); } RETURN_IF_FALSE(RemoveFixedAndEquivalentVariables(log_round_info)); @@ -1956,6 +1957,7 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( PresolveTimer& timer) { ids3_.clear(); ids4_.clear(); + ids5_.clear(); truth_tables_inputs_.clear(); truth_tables_bitset_.clear(); truth_tables_clauses_.clear(); @@ -1971,6 +1973,8 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( AddToTruthTable<3>(clause, ids3_); } else if (clause->size() == 4) { AddToTruthTable<4>(clause, ids4_); + } else if (clause->size() == 5) { + AddToTruthTable<5>(clause, ids5_); } // Used for an optimization below. @@ -2103,68 +2107,235 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( timer.AddCounter("and_gates", gates_inputs_.size()); } +int GateCongruenceClosure::ProcessTruthTable( + absl::Span inputs, SmallBitset truth_table, + absl::Span ids_for_proof) { + int num_detected = 0; + for (int i = 0; i < inputs.size(); ++i) { + if (!IsFunction(i, inputs.size(), truth_table)) continue; + const int num_bits = inputs.size() - 1; + ++num_detected; + + gates_target_.push_back(Literal(inputs[i], true)); + gates_inputs_.Add({}); + for (int j = 0; j < inputs.size(); ++j) { + if (i != j) { + gates_inputs_.AppendToLastVector(Literal(inputs[j], true)); + } + } + + // Generate the function truth table as a type. + // We will canonicalize it further in the main loop. + unsigned int type = 0; + for (int p = 0; p < (1 << num_bits); ++p) { + // Expand from (num_bits == inputs.size() - 1) bits to (inputs.size()) + // bits. + const int bigger_p = AddHoleAtPosition(i, p); + + if ((truth_table >> (bigger_p + (1 << i))) & 1) { + // target is 1 at this position. + type |= 1 << p; + DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function. + } else { + // Note that if there is no feasible assignment for a given p if + // [(truth_table >> bigger_p) & 1] is not 1. + // + // Here we could learn smaller clause, but also we don't really care + // what is the value of the target at that point p, so we use zero. + // + // TODO(user): This is not ideal if another function has a value of 1 at + // point p, we could still merge it with this one. Shall we create two + // gate type? or change the algo? + } + } + + gates_type_.push_back(type); + if (lrat_proof_handler_ != nullptr) { + gates_clauses_.Add({}); + for (const TruthTableId id : ids_for_proof) { + gates_clauses_.AppendToLastVector(truth_tables_clauses_[id]); + } + } + } + return num_detected; +} + +namespace { + +// Return a BooleanVariable from b that is not in a or kNoBooleanVariable; +BooleanVariable FindMissing(absl::Span vars_a, + absl::Span vars_b) { + for (const BooleanVariable b : vars_b) { + bool found = false; + for (const BooleanVariable a : vars_a) { + if (a == b) { + found = true; + break; + } + } + if (!found) return b; + } + return kNoBooleanVariable; +} + +} // namespace + +// TODO(user): It should be possible to extract ALL possible short gate, but +// we are not there yet. void GateCongruenceClosure::ExtractShortGates(PresolveTimer& timer) { if (lrat_proof_handler_ != nullptr) { - truth_tables_clauses_.ResetFromFlatMapping(tmp_ids_, tmp_clauses_); + truth_tables_clauses_.ResetFromFlatMapping( + tmp_ids_, tmp_clauses_, + /*minimum_num_nodes=*/truth_tables_bitset_.size()); CHECK_EQ(truth_tables_bitset_.size(), truth_tables_clauses_.size()); } + // This is used to combine two 3 arity table into one 4 arity one if + // they share two variables. + absl::flat_hash_map, int> binary_index_map; + std::vector flat_binary_index; + std::vector flat_table_id; + // Counters. - // We only fill a subset of the entries, but that makes the code shorter. - std::vector num_tables(5); - std::vector num_functions(5); + // We only fill a subset of the entries, but that make the code shorter. + std::vector num_tables(6); + std::vector num_functions(6); // Note that using the indirection via TruthTableId allow this code to // be deterministic. CHECK_EQ(truth_tables_bitset_.size(), truth_tables_inputs_.size()); - for (TruthTableId id(0); id < truth_tables_inputs_.size(); ++id) { - const absl::Span inputs = truth_tables_inputs_[id]; - const SmallBitset truth_table = truth_tables_bitset_[id]; - ++num_tables[inputs.size()]; - for (int i = 0; i < inputs.size(); ++i) { - if (!IsFunction(i, inputs.size(), truth_table)) continue; - - const int num_bits = inputs.size() - 1; - ++num_functions[num_bits]; - gates_target_.push_back(Literal(inputs[i], true)); - gates_inputs_.Add({}); - for (int j = 0; j < inputs.size(); ++j) { - if (i != j) { - gates_inputs_.AppendToLastVector(Literal(inputs[j], true)); - } + // Given a table of arity 4, this merges all the information from the tables + // of arity 3 included in it. + int num_merges = 0; + const auto merge3_into_4 = [this, &num_merges]( + absl::Span inputs, + SmallBitset& truth_table, + std::vector& ids_for_proof) { + DCHECK_EQ(inputs.size(), 4); + for (int i_to_remove = 0; i_to_remove < inputs.size(); ++i_to_remove) { + int pos = 0; + std::array key3; + for (int i = 0; i < inputs.size(); ++i) { + if (i == i_to_remove) continue; + key3[pos++] = inputs[i]; } + const auto it = ids3_.find(key3); + if (it == ids3_.end()) continue; + ++num_merges; - // Generate the function truth table as a type. - // We will canonicalize it further in the main loop. - unsigned int type = 0; - for (int p = 0; p < (1 << num_bits); ++p) { - // Expand from (num_bits == inputs.size() - 1) bits to (inputs.size()) - // bits. - const int bigger_p = AddHoleAtPosition(i, p); + // Extend the bitset from the table3 so that it is expressed correctly + // for the given inputs. + const TruthTableId id3 = it->second; + std::array key4; + for (int i = 0; i < 3; ++i) key4[i] = key3[i]; + key4[3] = FindMissing(key3, inputs); + SmallBitset bitset = truth_tables_bitset_[id3]; + bitset |= bitset << (1 << 3); // Extend for a new variable. + CanonicalizeTruthTable(absl::MakeSpan(key4), bitset); + CHECK_EQ(inputs, absl::MakeSpan(key4)); + truth_table &= bitset; - if ((truth_table >> (bigger_p + (1 << i))) & 1) { - // target is 1 at this position. - type |= 1 << p; - DCHECK_NE((truth_table >> bigger_p) & 1, 1); // Proper function. - } else { - // Note that if there is no feasible assignment for a given p, first - // we could have learned a smaller clause, but also we don't really - // care what is the value of the target at that point p, so we use - // zero. - } - } - - gates_type_.push_back(type); + // We need to add the corresponding clause! if (lrat_proof_handler_ != nullptr) { - gates_clauses_.Add(truth_tables_clauses_[id]); + ids_for_proof.push_back(id3); + } + } + }; + + // Starts by processing all existing tables. + // + // TODO(user): Since we deal with and_gates differently, do we need to look at + // binary clauses here ? or that is not needed ? I think there is only two + // kind of Boolean function on two inputs (and_gates, with any possible + // negation) and xor_gate. + std::vector ids_for_proof; + for (TruthTableId t_id(0); t_id < truth_tables_inputs_.size(); ++t_id) { + ids_for_proof.clear(); + ids_for_proof.push_back(t_id); + const absl::Span inputs = truth_tables_inputs_[t_id]; + SmallBitset truth_table = truth_tables_bitset_[t_id]; + + // Merge any size-3 table included inside a size-4 gate. + // TODO(user): do it for larger gate too ? + if (inputs.size() == 4) { + merge3_into_4(inputs, truth_table, ids_for_proof); + } + + ++num_tables[inputs.size()]; + const int num_detected = + ProcessTruthTable(inputs, truth_table, ids_for_proof); + num_functions[inputs.size() - 1] += num_detected; + + // If this is not a function and of size 3, lets try to combine it with + // another truth table of size 3 to get a table of size 4. + if (inputs.size() == 3 && num_detected == 0) { + for (int i = 0; i < 3; ++i) { + std::array key{inputs[i != 0 ? 0 : 1], + inputs[i != 2 ? 2 : 1]}; + DCHECK(std::is_sorted(key.begin(), key.end())); + const auto [it, inserted] = + binary_index_map.insert({key, binary_index_map.size()}); + flat_binary_index.push_back(it->second); + flat_table_id.push_back(t_id); + } + } + } + gtl::STLClearObject(&binary_index_map); + + // Detects ITE gates and potentially other 3-gates from a truth table of + // 4-entries formed by two 3-entries table. This just create a 4-entries + // table that will be processed below. + CompactVectorVector candidates; + candidates.ResetFromFlatMapping(flat_binary_index, flat_table_id); + gtl::STLClearObject(&flat_binary_index); + gtl::STLClearObject(&flat_table_id); + int num_combinations = 0; + for (int c = 0; c < candidates.size(); ++c) { + if (candidates[c].size() < 2) continue; + if (candidates[c].size() > 10) continue; // Too many? use heuristic. + + for (int a = 0; a < candidates[c].size(); ++a) { + for (int b = a + 1; b < candidates[c].size(); ++b) { + const absl::Span inputs_a = + truth_tables_inputs_[candidates[c][a]]; + const absl::Span inputs_b = + truth_tables_inputs_[candidates[c][b]]; + + std::array key; + for (int i = 0; i < 3; ++i) key[i] = inputs_a[i]; + key[3] = FindMissing(inputs_a, inputs_b); + CHECK_NE(key[3], kNoBooleanVariable); + + // Add an all allowed entry. + SmallBitset bitmask = GetNumBitsAtOne(4); + CanonicalizeTruthTable(absl::MakeSpan(key), bitmask); + + // If the key was not processed before, process it now. + // Note that an old version created a TruthTableId for it, but that + // waste a lot of space. + // + // On another hand, it is possible we process the same key up to + // 4_choose_2 times, but this is rare... + if (!ids4_.contains(key)) { + ++num_combinations; + ++num_tables[4]; + ids_for_proof.clear(); + merge3_into_4(key, bitmask, ids_for_proof); + num_functions[3] += ProcessTruthTable(key, bitmask, ids_for_proof); + } } } } + timer.AddCounter("combine3", num_combinations); + timer.AddCounter("merges", num_merges); + // Note that we only display non-zero counters. - for (int i = 2; i < 5; ++i) { - timer.AddCounter(absl::StrCat("table", i), num_tables[i]); + for (int i = 2; i < num_tables.size(); ++i) { + timer.AddCounter(absl::StrCat("t", i), num_tables[i]); + } + for (int i = 2; i < num_functions.size(); ++i) { timer.AddCounter(absl::StrCat("fn", i), num_functions[i]); } } @@ -2583,8 +2754,9 @@ class LratGateCongruenceHelper { } // namespace bool GateCongruenceClosure::DoOneRound(bool log_info) { + // TODO(user): Remove this condition, it is possible there are no binary + // and still gates! if (implication_graph_->IsEmpty()) return true; - clause_manager_->DetachAllClauses(); PresolveTimer timer("GateCongruenceClosure", logger_, time_limit_); timer.OverrideLogging(log_info); @@ -2725,9 +2897,9 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { const LiteralIndex y = negate ? rep_b.NegatedIndex() : rep_b.Index(); - // TODO(user): We need to change the union_find algo used to be sure - // the that if rep(x) = y then rep(not(x)) = not(y), otherwise we - // might miss some reductions. + // Because x always refer to a and y to b, this should maintain + // the invariant root(lit) = root(lit.Negated()).Negated(). + // This is checked below. union_find.AddEdge(x.value(), y.value()); const LiteralIndex rep(union_find.FindRoot(y.value())); const LiteralIndex to_merge = rep == x ? y : x; @@ -2758,6 +2930,12 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { in_queue[gate_id.value()] = true; } } + + // Invariant. + CHECK_EQ( + lrat_helper.GetRepresentativeWithProofSupport(rep_a), + lrat_helper.GetRepresentativeWithProofSupport(rep_a.Negated()) + .Negated()); } break; } @@ -2840,8 +3018,8 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { // Generic "short" gates. // We just take the representative and re-canonicalize. absl::Span inputs = gates_inputs_[id]; - CHECK_GE(gates_type_[id], 0); - CHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0); + DCHECK_GE(gates_type_[id], 0); + DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0); for (LiteralIndex& lit_ref : inputs) { lit_ref = lrat_helper.GetRepresentativeWithProofSupport(Literal(lit_ref)) @@ -2853,6 +3031,7 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { if (new_size < inputs.size()) { gates_inputs_.Shrink(id, new_size); } + DCHECK_EQ(gates_type_[id] >> (1 << (inputs.size())), 0); if (new_size == 1) { // We have a function of size 1! This is an equivalence. @@ -2863,7 +3042,6 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { break; } else if (new_size == 0) { // We have a fixed function! Just fix the literal. - CHECK(Literal(gates_target_[id]).IsPositive()); const Literal initial_to_fix = (gates_type_[id] & 1) == 1 ? Literal(gates_target_[id]) : Literal(gates_target_[id]).Negated(); diff --git a/ortools/sat/sat_inprocessing.h b/ortools/sat/sat_inprocessing.h index b1119dfc69..3cd770d4d8 100644 --- a/ortools/sat/sat_inprocessing.h +++ b/ortools/sat/sat_inprocessing.h @@ -500,6 +500,12 @@ class GateCongruenceClosure { // functions of the form one_var = f(other_vars). void ExtractShortGates(PresolveTimer& timer); + // Detects gates encoded in the given truth table, and add them to the set + // of gates. Returns the number of gate detected. + int ProcessTruthTable(absl::Span inputs, + SmallBitset truth_table, + absl::Span ids_for_proof = {}); + // Add a small clause to the corresponding truth table. template void AddToTruthTable(SatClause* clause, @@ -550,6 +556,7 @@ class GateCongruenceClosure { // truth_tables_inputs_, this is a bit wasted but simplify the code. absl::flat_hash_map, TruthTableId> ids3_; absl::flat_hash_map, TruthTableId> ids4_; + absl::flat_hash_map, TruthTableId> ids5_; CompactVectorVector truth_tables_inputs_; util_intops::StrongVector truth_tables_bitset_; CompactVectorVector truth_tables_clauses_; diff --git a/ortools/sat/util.h b/ortools/sat/util.h index deb1e0e9dd..a9082251a7 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -134,6 +134,7 @@ class CompactVectorVector { // Returns the previous size() as this is convenient for how we use it. int Add(absl::Span values); void AppendToLastVector(const V& value); + void AppendToLastVector(absl::Span values); // Hacky: same as Add() but for sat::Literal or any type from which we can get // a value type V via L.Index().value(). @@ -919,6 +920,13 @@ inline void CompactVectorVector::AppendToLastVector(const V& value) { buffer_.push_back(value); } +template +inline void CompactVectorVector::AppendToLastVector( + absl::Span values) { + sizes_.back() += values.size(); + buffer_.insert(buffer_.end(), values.begin(), values.end()); +} + template inline void CompactVectorVector::ReplaceValuesBySmallerSet( K key, absl::Span values) {