From 4dab47eaa63bfa5823be875fa32b071bf6a6d13e Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 15 Dec 2025 13:42:37 +0100 Subject: [PATCH] [CP-SAT] bugfixes --- ortools/sat/cp_model_presolve.cc | 16 +++--- ortools/sat/cp_model_solver.cc | 2 - ortools/sat/python/cp_model_test.py | 28 +++------- ortools/sat/sat_solver.cc | 84 ++++++++++++++++++++--------- ortools/sat/sat_solver.h | 6 ++- ortools/sat/solution_crush.cc | 37 +++++-------- ortools/sat/solution_crush.h | 23 ++++---- ortools/sat/variable_expand.cc | 27 ++++++++-- 8 files changed, 131 insertions(+), 92 deletions(-) diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 9dc1e6b5d5..852305a872 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -14634,6 +14634,16 @@ CpSolverStatus CpModelPresolver::Presolve() { // Sync the domains and initialize the mapping model variables. context_->WriteVariableDomainsToProto(); + + // Some vars may have been fixed by the affine relations. This may can impact + // the objective. Let's re-do the canonicalization. + if (context_->working_model->has_objective()) { + // We re-do a canonicalization with the final linear expression. + if (!context_->CanonicalizeObjective()) return InfeasibleStatus(); + context_->WriteObjectiveToProto(); + } + + // Starts the postsolve mapping model. InitializeMappingModelVariables(context_->AllDomains(), &fixed_postsolve_mapping, context_->mapping_model); @@ -14711,12 +14721,6 @@ CpSolverStatus CpModelPresolver::Presolve() { *postsolve_mapping_ = std::move(new_postsolve_mapping); } - if (context_->working_model->has_objective()) { - // We re-do a canonicalization with the final linear expression. - if (!context_->CanonicalizeObjective()) return InfeasibleStatus(); - context_->WriteObjectiveToProto(); - } - DCHECK(context_->ConstraintVariableUsageIsConsistent()); CanonicalizeRoutesConstraintNodeExpressions(context_); UpdateHintInProto(context_); diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 300dff5d23..a116cb39b6 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -250,7 +250,6 @@ void DumpNoOverlap2dProblem(const ConstraintProto& ct, std::vector sizes_to_render; IntegerValue x = bounding_box.x_min; IntegerValue y = 0; - int i = 0; for (const auto& r : non_fixed_boxes) { sizes_to_render.push_back(Rectangle{ .x_min = x, .x_max = x + r.x_size, .y_min = y, .y_max = y + r.y_size}); @@ -259,7 +258,6 @@ void DumpNoOverlap2dProblem(const ConstraintProto& ct, x = 0; y += r.y_size; } - ++i; } VLOG(3) << "Sizes: " << RenderDot(bounding_box, sizes_to_render); } diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index fc3d13543a..3342d18da4 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -165,19 +165,6 @@ class BestBoundCallback: self.best_bound = bb -class BestBoundTimeCallback: - - def __init__(self) -> None: - self.__last_time: float = 0.0 - - def new_best_bound(self, unused_bb: float): - self.__last_time = time.time() - - @property - def last_time(self) -> float: - return self.__last_time - - class CpModelTest(absltest.TestCase): def tearDown(self) -> None: @@ -264,7 +251,7 @@ class CpModelTest(absltest.TestCase): self.assertEqual(nb.index, -b.index - 1) self.assertRaises(TypeError, x.negated) - def test_issue_4654(self) -> None: + def test_issue4654(self) -> None: model = cp_model.CpModel() x = model.NewIntVar(0, 1, "x") y = model.NewIntVar(0, 2, "y") @@ -2457,18 +2444,15 @@ TRFM""" # Solve. solver = cp_model.CpSolver() - solver.parameters.num_workers = 8 + solver.parameters.num_workers = 1 solver.parameters.max_time_in_seconds = 50 solver.parameters.log_search_progress = True - solution_callback = TimeRecorder() - best_bound_callback = BestBoundTimeCallback() + best_bound_callback = BestBoundCallback() solver.best_bound_callback = best_bound_callback.new_best_bound - status = solver.Solve(model, solution_callback) + status = solver.Solve(model) if status == cp_model.OPTIMAL: - last_activity = max( - best_bound_callback.last_time, solution_callback.last_time - ) - self.assertLess(time.time(), last_activity + 30.0) + # Optimal is 28. The first bound found is 19.0. + self.assertGreaterEqual(best_bound_callback.best_bound, 19.0) def test_issue4434(self) -> None: model = cp_model.CpModel() diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 933ec06920..cc6a95ad65 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -1059,20 +1059,58 @@ void SatSolver::ProcessCurrentConflict( Backtrack(backtrack_level); DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); - // Tricky: in case of propagation not at the right level we might need to - // backjump further. - for (const auto& [id, is_redundant, min_lbd, clause] : delayed_to_add_) { + // Add the conflict here, so we process all "newly learned" clause in the + // same way. + learned_clauses_.push_back({learned_conflict_clause_id, is_redundant, + min_lbd_of_subsumed_clauses, + std::move(learned_conflict_)}); + + // Preprocess the new clauses. + // We might need to backtrack further ! + for (auto& [id, is_redundant, min_lbd, clause] : learned_clauses_) { if (clause.empty()) return (void)SetModelUnsat(); - // TODO(user): just remove redundant literal from learned clauses. This - // should just be better. We just have to deal with the proof correctly. - if (clause.size() == 2 && - binary_implication_graph_->RepresentativeOf(clause[0]) == - binary_implication_graph_->RepresentativeOf(clause[1])) { - Backtrack(0); - break; + // Make sure each clause is "canonicalized" with respect to equivalent + // literals. + // + // TODO(user): Maybe we should do that on each reason before we use them in + // conflict analysis/minimization, but it might be a bit costly. + bool some_change = false; + tmp_clause_ids_.clear(); + for (Literal& lit : clause) { + const Literal rep = binary_implication_graph_->RepresentativeOf(lit); + if (rep != lit) { + some_change = true; + if (lrat_proof_handler_ != nullptr) { + // We need not(rep) => not(lit) for the proof. + tmp_clause_ids_.push_back( + binary_implication_graph_->GetClauseId(lit.Negated(), rep)); + CHECK_NE(tmp_clause_ids_.back(), kNoClauseId) << lit << " " << rep; + } + lit = rep; + } + } + if (some_change) { + gtl::STLSortAndRemoveDuplicates(&clause); + + // This shouldn't happen since it is a new learned clause, otherwise + // something is wrong. + for (int i = 1; i < clause.size(); ++i) { + CHECK_NE(clause[i], clause[i - 1].Negated()) << "trivial new clause?"; + } + + if (lrat_proof_handler_ != nullptr) { + // We need a new clause id for the canonicalized version, and the proof + // for how we derived that canonicalization. + const ClauseId new_id = clause_id_generator_->GetNextId(); + tmp_clause_ids_.push_back(id); + lrat_proof_handler_->AddInferredClause(new_id, clause, tmp_clause_ids_); + id = new_id; + } } + // Tricky: in case of propagation not at the right level we might need to + // backjump further. int num_false = 0; for (const Literal l : clause) { if (Assignment().LiteralIsFalse(l)) ++num_false; @@ -1094,19 +1132,15 @@ void SatSolver::ProcessCurrentConflict( } } - // Add any delayed clause before the final conflict. - for (const auto& [id, is_redundant, min_lbd, clause] : delayed_to_add_) { + // Learn the new clauses. + int best_lbd = std::numeric_limits::max(); + for (const auto& [id, is_redundant, min_lbd, clause] : learned_clauses_) { DCHECK((lrat_proof_handler_ == nullptr) || (id != kNoClauseId)); - AddLearnedClauseAndEnqueueUnitPropagation(id, clause, is_redundant, - min_lbd); + const int lbd = AddLearnedClauseAndEnqueueUnitPropagation( + id, clause, is_redundant, min_lbd); + best_lbd = std::min(best_lbd, lbd); } - - // Create and attach the new learned clause. - const int conflict_lbd = AddLearnedClauseAndEnqueueUnitPropagation( - learned_conflict_clause_id, learned_conflict_, is_redundant, - min_lbd_of_subsumed_clauses); - - restart_->OnConflict(conflict_trail_index, conflict_level, conflict_lbd); + restart_->OnConflict(conflict_trail_index, conflict_level, best_lbd); } namespace { @@ -1128,7 +1162,7 @@ std::pair SatSolver::SubsumptionsInConflictResolution( ClauseId learned_conflict_id, absl::Span conflict, absl::Span reason_used) { CHECK_NE(CurrentDecisionLevel(), 0); - delayed_to_add_.clear(); + learned_clauses_.clear(); // This is used to see if the learned conflict subsumes some clauses. // Note that conflict is not yet in the clauses_propagator_. @@ -1249,7 +1283,7 @@ std::pair SatSolver::SubsumptionsInConflictResolution( // We can only add them after backtracking, since these are currently // conflict. - delayed_to_add_.push_back( + learned_clauses_.push_back( {new_id, new_clause_is_redundant, new_clause_min_lbd, std::vector(subsuming_clauses_[i].begin(), subsuming_clauses_[i].end())}); @@ -1345,8 +1379,8 @@ std::pair SatSolver::SubsumptionsInConflictResolution( } // Also learn the "decision" conflict. - delayed_to_add_.push_back({new_clause_id, decision_is_redundant, - decision_min_lbd, decision_clause}); + learned_clauses_.push_back({new_clause_id, decision_is_redundant, + decision_min_lbd, decision_clause}); } // Sparse clear. diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 7a77f5e225..07ecfff85b 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -879,13 +879,15 @@ class SatSolver { CompactVectorVector subsuming_clauses_; CompactVectorVector subsuming_groups_; - struct DelayedNewClause { + // On each conflict, we learn at least one clause, but depending on the cases, + // we can learn more than one. + struct NewClauses { ClauseId id; bool is_redundant; int min_lbd_of_subsumed_clauses; std::vector clause; }; - std::vector delayed_to_add_; + std::vector learned_clauses_; // When true, temporarily disable the deletion of clauses that are not needed // anymore. This is a hack for TryToMinimizeClause() because we use diff --git a/ortools/sat/solution_crush.cc b/ortools/sat/solution_crush.cc index 64076af3e2..cea47abde0 100644 --- a/ortools/sat/solution_crush.cc +++ b/ortools/sat/solution_crush.cc @@ -236,38 +236,29 @@ void SolutionCrush::SetOrUpdateVarToDomain(int var, const Domain& domain) { } } -void SolutionCrush::SetOrUpdateVarToDomain( - int var, const Domain& domain, - const absl::btree_map& encoding, +void SolutionCrush::SetOrUpdateVarToDomainWithOptionalEscapeValue( + int var, const Domain& reduced_var_domain, std::optional unique_escape_value, - bool push_down_when_repairing_hints) { - DCHECK_EQ(domain.Size(), encoding.size()); + bool push_down_when_not_in_domain, + const absl::btree_map& encoding) { if (!solution_is_loaded_) return; if (HasValue(var)) { const int64_t old_value = GetVarValue(var); - if (domain.Contains(old_value)) return; - int64_t new_value = old_value; - if (unique_escape_value.has_value()) { // Only one escape value. + if (reduced_var_domain.Contains(old_value)) return; + if (unique_escape_value.has_value()) { new_value = unique_escape_value.value(); - } else if (push_down_when_repairing_hints) { - DCHECK_GT(old_value, domain.Min()); - new_value = domain.ValueAtOrBefore(old_value); + } else if (push_down_when_not_in_domain) { + DCHECK_GT(old_value, reduced_var_domain.Min()); + new_value = reduced_var_domain.ValueAtOrBefore(old_value); } else { - new_value = domain.ValueAtOrAfter(old_value); - } - for (const auto [value, lit] : encoding) { - SetLiteralValue(lit, value == new_value); + DCHECK_LT(old_value, reduced_var_domain.Max()); + new_value = reduced_var_domain.ValueAtOrAfter(old_value); } + + SetLiteralValue(encoding.at(new_value), true); + CHECK(!encoding.contains(old_value)); SetVarValue(var, new_value); - VLOG(3) << "SetOrUpdateVarToDomain: " << var << ", old_value: " << old_value - << ", new_value: " << new_value - << ", domain: " << domain.ToString(); - DCHECK(encoding.contains(new_value)) - << "domain: " << domain.ToString() << "old_value: " << old_value - << " new_value: " << new_value; - } else if (domain.IsFixed()) { - SetVarValue(var, domain.FixedValue()); } } diff --git a/ortools/sat/solution_crush.h b/ortools/sat/solution_crush.h index 1c014164f3..484e839568 100644 --- a/ortools/sat/solution_crush.h +++ b/ortools/sat/solution_crush.h @@ -151,15 +151,20 @@ class SolutionCrush { // value. Otherwise does nothing. void SetOrUpdateVarToDomain(int var, const Domain& domain); - // If `var` already has a value, updates it to be within the given domain - // following the given encoding and the status of the variable w.r.t. the - // escape value, and the objective. Otherwise, if the domain is fixed, sets - // the value of `var` to this fixed value. Otherwise does nothing. In the - // process, update the encoding literals to reflect the new value of `var`. - void SetOrUpdateVarToDomain(int var, const Domain& domain, - const absl::btree_map& encoding, - std::optional unique_escape_value, - bool push_down_when_repairing_hints); + // If `var` already has a value, updates it to be within the given domain. + // There are 3 cases to consider: + // 1/ The hinted value is in reduced_var_domain. Nothing to do. + // 2/ The hinted value is not in the domain, and there is an escape value. + // Update the hinted value to the escape value, and update the encoding + // literals to reflect the new value of `var`. + // 3/ The hinted value is not in the domain, and there is no escape value. + // Update the hinted value to be in the domain by pushing it in the given + // direction, and update the encoding literals to reflect the new value + void SetOrUpdateVarToDomainWithOptionalEscapeValue( + int var, const Domain& reduced_var_domain, + std::optional unique_escape_value, + bool push_down_when_not_in_domain, + const absl::btree_map& encoding); // Updates the value of the given literals to false if their current values // are different (or does nothing otherwise). diff --git a/ortools/sat/variable_expand.cc b/ortools/sat/variable_expand.cc index 8e1e333ec7..c068807a4e 100644 --- a/ortools/sat/variable_expand.cc +++ b/ortools/sat/variable_expand.cc @@ -658,11 +658,32 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context, values.CreateAllValueEncodingLiterals(); // Fix the hinted value if needed. + // + // The logic follows the classes of equivalence induced by the value of the + // literals from the enforced linear1 constraining this variable. + // Two values are in the same class if all the literals have the same value. + // + // We have a heuristic method here: + // - If the variable is in the domain, we do nothing. + // - If the variable has only var==value and var!=value encodings. All values + // not touched by these linear1 are equivalent. We will reassign them to the + // unique escape value. + // - If the variable also has var>=value and var<=value encodings, we will + // push the value of the variable to the closest value in the domain in the + // direction of the objective. To this effect, for every contiguous set of + // values not in the set of referenced values. the min of the max of that + // set has been added to the encoded domain, such that the push up or down + // always falls back on an encoded value. + // + // TODO(user): we could optimize this as, for instance, we only need to + // look at values from the order encodings, and not all values when creating + // the equivalence class in the last case. const bool push_down_when_unconstrained = !var_in_objective || var_has_positive_objective_coefficient; - solution_crush.SetOrUpdateVarToDomain( - var, Domain::FromValues(values.encoded_values()), values.encoding(), - values.unique_escape_value(), push_down_when_unconstrained); + solution_crush.SetOrUpdateVarToDomainWithOptionalEscapeValue( + var, Domain::FromValues(values.encoded_values()), + values.unique_escape_value(), push_down_when_unconstrained, + values.encoding()); order.CreateAllOrderEncodingLiterals(values); // Link all Boolean in our linear1 to the encoding literals.