diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index ee8db8a648..8b0a2881f5 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -486,7 +486,6 @@ cc_library( hdrs = ["presolve_context.h"], deps = [ ":cp_model_cc_proto", - ":cp_model_checker", ":cp_model_loader", ":cp_model_mapping", ":cp_model_utils", @@ -511,6 +510,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/numeric:int128", @@ -1010,6 +1010,7 @@ cc_library( ":integer_search", ":linear_programming_constraint", ":model", + ":pseudo_costs", ":sat_base", ":sat_decision", ":sat_parameters_cc_proto", @@ -1032,7 +1033,10 @@ cc_library( srcs = ["pseudo_costs.cc"], hdrs = ["pseudo_costs.h"], deps = [ + ":cp_model_mapping", ":integer", + ":linear_constraint_manager", + ":linear_programming_constraint", ":model", ":sat_base", ":sat_parameters_cc_proto", @@ -1041,6 +1045,8 @@ cc_library( "//ortools/base:strong_vector", "//ortools/util:strong_integers", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], ) diff --git a/ortools/sat/colab/cp_sat.ipynb b/ortools/sat/colab/cp_sat.ipynb index 88650543a1..474fb8e3a3 100644 --- a/ortools/sat/colab/cp_sat.ipynb +++ b/ortools/sat/colab/cp_sat.ipynb @@ -179,6 +179,7 @@ " #\n", "\n", " solver = cp_model.CpSolver()\n", + " solver.parameters.log_search_progress = True\n", " status = solver.Solve(model)\n", "\n", " if status == cp_model.FEASIBLE or status == cp_model.OPTIMAL:\n", @@ -400,6 +401,7 @@ "\n", " # Solve model.\n", " solver = cp_model.CpSolver()\n", + " solver.parameters.log_search_progress = True\n", " solver.Solve(model)\n", "\n", " # Output solution.\n", diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 71fae1e0bb..0cd842070b 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -238,10 +238,12 @@ std::function ConstructUserSearchStrategy( value = -(coeff * ub + offset); break; case DecisionStrategyProto::CHOOSE_MIN_DOMAIN_SIZE: - value = coeff * (ub - lb + 1); + // The size of the domain is not multiplied by the coeff. + value = ub - lb + 1; break; case DecisionStrategyProto::CHOOSE_MAX_DOMAIN_SIZE: - value = -coeff * (ub - lb + 1); + // The size of the domain is not multiplied by the coeff. + value = -(ub - lb + 1); break; default: LOG(FATAL) << "Unknown VariableSelectionStrategy " diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 53bccd1152..7d3d9fcd33 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -200,6 +200,95 @@ std::function MostFractionalHeuristic(Model* model) { }; } +std::function BoolPseudoCostHeuristic(Model* model) { + auto* lp_values = model->GetOrCreate(); + auto* encoder = model->GetOrCreate(); + auto* pseudo_costs = model->GetOrCreate(); + auto* integer_trail = model->GetOrCreate(); + return [lp_values, encoder, pseudo_costs, integer_trail]() { + double best_score = 0.0; + BooleanOrIntegerLiteral decision; + for (IntegerVariable var(0); var < lp_values->size(); var += 2) { + // Only look at non-fixed booleans. + const IntegerValue lb = integer_trail->LowerBound(var); + const IntegerValue ub = integer_trail->UpperBound(var); + if (lb != 0 || ub != 1) continue; + + // Get associated literal. + const LiteralIndex index = + encoder->GetAssociatedLiteral(IntegerLiteral::GreaterOrEqual(var, 1)); + if (index == kNoLiteralIndex) continue; + + const double lp_value = (*lp_values)[var]; + const double score = + pseudo_costs->BoolPseudoCost(Literal(index), lp_value); + if (score > best_score) { + best_score = score; + decision = BooleanOrIntegerLiteral(Literal(index)); + } + } + return decision; + }; +} + +std::function LpPseudoCostHeuristic(Model* model) { + auto* lp_values = model->GetOrCreate(); + auto* integer_trail = model->GetOrCreate(); + auto* pseudo_costs = model->GetOrCreate(); + auto* encoder = model->GetOrCreate(); + return [lp_values, pseudo_costs, integer_trail, encoder, model]() { + double best_score = 0.0; + BooleanOrIntegerLiteral decision; + for (IntegerVariable var(0); var < lp_values->size(); var += 2) { + const IntegerValue lb = integer_trail->LowerBound(var); + const IntegerValue ub = integer_trail->UpperBound(var); + if (lb == ub) continue; + + const double lp_value = (*lp_values)[var]; + const bool is_reliable = pseudo_costs->LpReliability(var) >= 4; + const bool is_integer = std::abs(lp_value - std::round(lp_value)) < 1e-6; + + // When not reliable, we skip integer. + // + // TODO(user): Use strong branching when not reliable. + // TODO(user): do not branch on integer lp? however it seems better to + // do that !? Maybe this is because if it has a high pseudo cost + // average, it is good anyway? + if (!is_reliable && is_integer) continue; + + // For Booleans, for some reason it seems the up-branch first work better? + if (lb == 0 && ub == 1) { + const double score = pseudo_costs->LpPseudoCost(var, lp_value); + if (score > best_score) { + const LiteralIndex index = encoder->GetAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, 1)); + if (index != kNoLiteralIndex) { + best_score = score; + decision = BooleanOrIntegerLiteral(Literal(index)); + } + } + } + + // There are some corner cases if we are at the bound. Note that it is + // important to be in sync with the SplitAroundLpValue() below. + double down_fractionality = lp_value - std::floor(lp_value); + if (lp_value >= ToDouble(ub)) down_fractionality = 1.0; + if (lp_value <= ToDouble(lb)) down_fractionality = 0.0; + const double score = pseudo_costs->LpPseudoCost(var, down_fractionality); + + // We delay to subsequent heuristic if the score is 0.0. + if (score > best_score) { + best_score = score; + + // This choose <= value if possible. + decision = BooleanOrIntegerLiteral(SplitAroundGivenValue( + var, IntegerValue(std::floor(lp_value)), model)); + } + } + return decision; + }; +} + std::function UnassignedVarWithLowestMinAtItsMinHeuristic( const std::vector& vars, Model* model) { @@ -349,8 +438,6 @@ std::function SatSolverHeuristic(Model* model) { }; } -// TODO(user): Do we need a mechanism to reduce the range of possible gaps -// when nothing gets proven? This could be a parameter or some adaptative code. std::function ShaveObjectiveLb(Model* model) { auto* objective_definition = model->GetOrCreate(); const IntegerVariable obj_var = objective_definition->objective_var; @@ -1261,11 +1348,7 @@ IntegerSearchHelper::IntegerSearchHelper(Model* model) product_detector_(model->GetOrCreate()), time_limit_(model->GetOrCreate()), pseudo_costs_(model->GetOrCreate()), - inprocessing_(model->GetOrCreate()) { - // This is needed for recording the pseudo-costs. - const ObjectiveDefinition* objective = model->Get(); - if (objective != nullptr) objective_var_ = objective->objective_var; -} + inprocessing_(model->GetOrCreate()) {} bool IntegerSearchHelper::BeforeTakingDecision() { // If we pushed root level deductions, we restart to incorporate them. @@ -1354,21 +1437,13 @@ bool IntegerSearchHelper::GetDecision( } bool IntegerSearchHelper::TakeDecision(Literal decision) { - // Record the changelist and objective bounds for updating pseudo costs. - const std::vector bound_changes = - pseudo_costs_->GetBoundChanges(decision); - IntegerValue old_obj_lb = kMinIntegerValue; - IntegerValue old_obj_ub = kMaxIntegerValue; - if (objective_var_ != kNoIntegerVariable) { - old_obj_lb = integer_trail_->LowerBound(objective_var_); - old_obj_ub = integer_trail_->UpperBound(objective_var_); - } - const int old_level = sat_solver_->CurrentDecisionLevel(); + pseudo_costs_->BeforeTakingDecision(decision); // Note that kUnsatTrailIndex might also mean ASSUMPTIONS_UNSAT. // // TODO(user): on some problems, this function can be quite long. Expand // so that we can check the time limit at each step? + const int old_level = sat_solver_->CurrentDecisionLevel(); const int index = sat_solver_->EnqueueDecisionAndBackjumpOnConflict(decision); if (index == kUnsatTrailIndex) return false; @@ -1380,14 +1455,8 @@ bool IntegerSearchHelper::TakeDecision(Literal decision) { } // Update the pseudo costs. - if (sat_solver_->CurrentDecisionLevel() > old_level && - objective_var_ != kNoIntegerVariable) { - const IntegerValue new_obj_lb = integer_trail_->LowerBound(objective_var_); - const IntegerValue new_obj_ub = integer_trail_->UpperBound(objective_var_); - const IntegerValue objective_bound_change = - (new_obj_lb - old_obj_lb) + (old_obj_ub - new_obj_ub); - pseudo_costs_->UpdateCost(bound_changes, objective_bound_change); - } + pseudo_costs_->AfterTakingDecision( + /*conflict=*/sat_solver_->CurrentDecisionLevel() <= old_level); sat_solver_->AdvanceDeterministicTime(time_limit_); return sat_solver_->ReapplyAssumptionsIfNeeded(); diff --git a/ortools/sat/integer_search.h b/ortools/sat/integer_search.h index 4cde8794ce..511b3d554d 100644 --- a/ortools/sat/integer_search.h +++ b/ortools/sat/integer_search.h @@ -174,6 +174,12 @@ std::function FirstUnassignedVarAtItsMinHeuristic( // Choose the variable with most fractional LP value. std::function MostFractionalHeuristic(Model* model); +// Variant used for LbTreeSearch experimentation. Note that each decision is in +// O(num_variables), but it is kind of ok with LbTreeSearch as we only call this +// for "new" decision, not when we move around in the tree. +std::function BoolPseudoCostHeuristic(Model* model); +std::function LpPseudoCostHeuristic(Model* model); + // Decision heuristic for SolveIntegerProblemWithLazyEncoding(). Like // FirstUnassignedVarAtItsMinHeuristic() but the function will return the // literal corresponding to the fact that the currently non-assigned variable @@ -309,7 +315,6 @@ class IntegerSearchHelper { TimeLimit* time_limit_; PseudoCosts* pseudo_costs_; Inprocessing* inprocessing_; - IntegerVariable objective_var_ = kNoIntegerVariable; bool must_process_conflict_ = false; }; diff --git a/ortools/sat/lb_tree_search.cc b/ortools/sat/lb_tree_search.cc index d7b127ddb9..db965a3d7d 100644 --- a/ortools/sat/lb_tree_search.cc +++ b/ortools/sat/lb_tree_search.cc @@ -34,6 +34,7 @@ #include "ortools/sat/integer_search.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/model.h" +#include "ortools/sat/pseudo_costs.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_decision.h" #include "ortools/sat/sat_parameters.pb.h" @@ -56,6 +57,7 @@ LbTreeSearch::LbTreeSearch(Model* model) integer_trail_(model->GetOrCreate()), watcher_(model->GetOrCreate()), shared_response_(model->GetOrCreate()), + pseudo_costs_(model->GetOrCreate()), sat_decision_(model->GetOrCreate()), search_helper_(model->GetOrCreate()), parameters_(*model->GetOrCreate()) { @@ -316,11 +318,16 @@ SatSolver::Status LbTreeSearch::Search( // TODO(user): This is slightly different than bumping each time we // push a decision that result in an LB increase. This is also called on // backjump for instance. - if (integer_trail_->LowerBound(objective_var_) > - integer_trail_->LevelZeroLowerBound(objective_var_)) { + const IntegerValue obj_diff = + integer_trail_->LowerBound(objective_var_) - + integer_trail_->LevelZeroLowerBound(objective_var_); + if (obj_diff > 0) { std::vector reason = integer_trail_->ReasonFor(IntegerLiteral::GreaterOrEqual( objective_var_, integer_trail_->LowerBound(objective_var_))); + + // TODO(user): We also need to update pseudo cost on conflict. + pseudo_costs_->UpdateBoolPseudoCosts(reason, obj_diff); sat_decision_->BumpVariableActivities(reason); sat_decision_->UpdateVariableActivityIncrement(); } @@ -516,7 +523,7 @@ SatSolver::Status LbTreeSearch::Search( // basically changes if we take the decision later when we explore the // branch or right now. // - // I feel taking it later is better. It also avoid creating uneeded nodes. + // I feel taking it later is better. It also avoid creating unneeded nodes. // It does change the behavior on a few problem though. For instance on // irp.mps.gz, the search works better without this, whatever the random // seed. Not sure why, maybe it creates more diversity? diff --git a/ortools/sat/lb_tree_search.h b/ortools/sat/lb_tree_search.h index 60d7779b5b..14a56778db 100644 --- a/ortools/sat/lb_tree_search.h +++ b/ortools/sat/lb_tree_search.h @@ -30,6 +30,7 @@ #include "ortools/sat/integer_search.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/model.h" +#include "ortools/sat/pseudo_costs.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_decision.h" #include "ortools/sat/sat_parameters.pb.h" @@ -147,6 +148,7 @@ class LbTreeSearch { IntegerTrail* integer_trail_; GenericLiteralWatcher* watcher_; SharedResponseManager* shared_response_; + PseudoCosts* pseudo_costs_; SatDecisionPolicy* sat_decision_; IntegerSearchHelper* search_helper_; IntegerVariable objective_var_; diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index e30f159553..a9fef9c255 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -37,7 +37,6 @@ #include "ortools/base/mathutil.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_loader.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_utils.h" @@ -1503,7 +1502,6 @@ bool PresolveContext::InsertVarValueEncoding(int literal, int ref, hint_[bool_var] = RefIsPositive(literal) ? hint_value : 1 - hint_value; } } - return true; } diff --git a/ortools/sat/pseudo_costs.cc b/ortools/sat/pseudo_costs.cc index b096152654..92d4ac95ea 100644 --- a/ortools/sat/pseudo_costs.cc +++ b/ortools/sat/pseudo_costs.cc @@ -16,12 +16,18 @@ #include #include #include +#include #include #include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/integer.h" +#include "ortools/sat/linear_constraint_manager.h" +#include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -31,24 +37,154 @@ namespace operations_research { namespace sat { +// We prefer the product to combine the cost of two branches. +double PseudoCosts::CombineCosts(double down_branch, double up_branch) const { + if (true) { + return std::max(1e-6, down_branch) * std::max(1e-6, up_branch); + } else { + const double min_value = std::min(up_branch, down_branch); + const double max_value = std::max(up_branch, down_branch); + const double mu = 1.0 / 6.0; + return (1.0 - mu) * min_value + mu * max_value; + } +} + +std::string PseudoCosts::ObjectiveInfo::DebugString() const { + return absl::StrCat("lb: ", lb, " ub:", ub, " lp_bound:", lp_bound); +} + PseudoCosts::PseudoCosts(Model* model) : parameters_(*model->GetOrCreate()), integer_trail_(model->GetOrCreate()), - encoder_(model->GetOrCreate()) { + encoder_(model->GetOrCreate()), + lp_values_(model->GetOrCreate()), + lps_(model->GetOrCreate()) { const int num_vars = integer_trail_->NumIntegerVariables().value(); pseudo_costs_.resize(num_vars); is_relevant_.resize(num_vars, false); scores_.resize(num_vars, 0.0); + + // If objective_var == kNoIntegerVariable, there is not really any point using + // this class. + auto* objective = model->Get(); + if (objective != nullptr) { + objective_var_ = objective->objective_var; + } } -void PseudoCosts::UpdateCost( - const std::vector& bound_changes, - const IntegerValue obj_bound_improvement) { +PseudoCosts::ObjectiveInfo PseudoCosts::GetCurrentObjectiveInfo() { + ObjectiveInfo result; + if (objective_var_ == kNoIntegerVariable) return result; + + result.lb = integer_trail_->LowerBound(objective_var_); + result.ub = integer_trail_->UpperBound(objective_var_); + + // We sum the objectives over the LP components. + // Note that in practice, when we use the pseudo-costs, there is just one. + result.lp_bound = 0.0; + result.lp_at_optimal = true; + for (const auto* lp : *lps_) { + if (!lp->HasSolution()) result.lp_at_optimal = false; + result.lp_bound += lp->SolutionObjectiveValue(); + } + return result; +} + +void PseudoCosts::BeforeTakingDecision(Literal decision) { + if (objective_var_ == kNoIntegerVariable) return; + saved_info_ = GetCurrentObjectiveInfo(); + bound_changes_ = GetBoundChanges(decision); +} + +double PseudoCosts::LpPseudoCost(IntegerVariable var, + double down_fractionality) const { + const int max_index = std::max(var.value(), NegationOf(var).value()); + if (max_index >= average_unit_objective_increase_.size()) return 0.0; + + const double up_fractionality = 1.0 - down_fractionality; + const double up_branch = + up_fractionality * average_unit_objective_increase_[var].CurrentAverage(); + const double down_branch = + down_fractionality * + average_unit_objective_increase_[NegationOf(var)].CurrentAverage(); + return CombineCosts(down_branch, up_branch); +} + +void PseudoCosts::UpdateBoolPseudoCosts(absl::Span reason, + IntegerValue objective_increase) { + const double relative_increase = + ToDouble(objective_increase) / static_cast(reason.size()); + for (const Literal lit : reason) { + if (lit.Index() >= lit_pseudo_costs_.size()) { + lit_pseudo_costs_.resize(lit.Index() + 1); + } + lit_pseudo_costs_[lit].AddData(relative_increase); + } +} + +double PseudoCosts::BoolPseudoCost(Literal lit, double lp_value) const { + if (lit.Index() >= lit_pseudo_costs_.size()) return 0.0; + + const double down_fractionality = lp_value; + const double up_fractionality = 1.0 - lp_value; + const double up_branch = + up_fractionality * lit_pseudo_costs_[lit].CurrentAverage(); + const double down_branch = + down_fractionality * + lit_pseudo_costs_[lit.NegatedIndex()].CurrentAverage(); + return CombineCosts(down_branch, up_branch); +} + +int PseudoCosts::LpReliability(IntegerVariable var) const { + const int max_index = std::max(var.value(), NegationOf(var).value()); + if (max_index >= average_unit_objective_increase_.size()) return 0; + + return std::min( + average_unit_objective_increase_[var].NumRecords(), + average_unit_objective_increase_[NegationOf(var)].NumRecords()); +} + +void PseudoCosts::AfterTakingDecision(bool conflict) { + if (objective_var_ == kNoIntegerVariable) return; + const ObjectiveInfo new_info = GetCurrentObjectiveInfo(); + + // We store a pseudo cost for this literal. We prefer the pure LP version, but + // revert to integer version if there is no lp. TODO(user): tune that. + // + // We only collect lp increase when the lp is at optimal, otherwise it might + // just be the "artificial" continuing of the current lp solve that create the + // increase. + if (saved_info_.lp_at_optimal) { + // Compute the increase in objective. + const double obj_lp_diff = + std::max(0.0, new_info.lp_bound - saved_info_.lp_bound); + const IntegerValue obj_int_diff = new_info.lb - saved_info_.lb; + double obj_diff = obj_lp_diff > 0.0 ? obj_lp_diff : ToDouble(obj_int_diff); + if (conflict) { + // We count a conflict as a max increase + 1.0 + obj_diff = ToDouble(saved_info_.ub) - ToDouble(saved_info_.lb) + 1.0; + } + + // Update the average unit increases. + for (const auto [var, lb_change, lp_increase] : bound_changes_) { + if (lp_increase < 1e-6) continue; + if (var >= average_unit_objective_increase_.size()) { + average_unit_objective_increase_.resize(var + 1); + } + average_unit_objective_increase_[var].AddData(obj_diff / lp_increase); + } + } + + // TODO(user): Handle this case. + if (conflict) return; + + // We also store one for any associated IntegerVariable. + const IntegerValue obj_bound_improvement = + (new_info.lb - saved_info_.lb) + (saved_info_.ub - new_info.ub); DCHECK_GE(obj_bound_improvement, 0); if (obj_bound_improvement == IntegerValue(0)) return; - const double epsilon = 1e-6; - for (const auto [var, lb_change] : bound_changes) { + for (const auto [var, lb_change, lp_increase] : bound_changes_) { if (lb_change == IntegerValue(0)) continue; if (var >= pseudo_costs_.size()) { @@ -67,9 +203,8 @@ void PseudoCosts::UpdateCost( const int64_t count = pseudo_costs_[positive_var].NumRecords() + pseudo_costs_[negative_var].NumRecords(); if (count >= parameters_.pseudo_cost_reliability_threshold()) { - scores_[positive_var] = std::max(GetCost(positive_var), epsilon) * - std::max(GetCost(negative_var), epsilon); - + scores_[positive_var] = + CombineCosts(GetCost(positive_var), GetCost(negative_var)); if (!is_relevant_[positive_var]) { is_relevant_[positive_var] = true; relevant_variables_.push_back(positive_var); @@ -111,30 +246,32 @@ std::vector PseudoCosts::GetBoundChanges( std::vector bound_changes; for (const IntegerLiteral l : encoder_->GetIntegerLiterals(decision)) { - PseudoCosts::VariableBoundChange var_bound_change; - var_bound_change.var = l.var; - var_bound_change.lower_bound_change = - l.bound - integer_trail_->LowerBound(l.var); - bound_changes.push_back(var_bound_change); + PseudoCosts::VariableBoundChange entry; + entry.var = l.var; + entry.lower_bound_change = l.bound - integer_trail_->LowerBound(l.var); + if (l.var < lp_values_->size()) { + entry.lp_increase = + std::max(0.0, ToDouble(l.bound) - (*lp_values_)[l.var]); + } + bound_changes.push_back(entry); } // NOTE: We ignore literal associated to var != value. for (const auto [var, value] : encoder_->GetEqualityLiterals(decision)) { { - PseudoCosts::VariableBoundChange var_bound_change; - var_bound_change.var = var; - var_bound_change.lower_bound_change = - value - integer_trail_->LowerBound(var); - bound_changes.push_back(var_bound_change); + PseudoCosts::VariableBoundChange entry; + entry.var = var; + entry.lower_bound_change = value - integer_trail_->LowerBound(var); + bound_changes.push_back(entry); } // Also do the negation. { - PseudoCosts::VariableBoundChange var_bound_change; - var_bound_change.var = NegationOf(var); - var_bound_change.lower_bound_change = + PseudoCosts::VariableBoundChange entry; + entry.var = NegationOf(var); + entry.lower_bound_change = (-value) - integer_trail_->LowerBound(NegationOf(var)); - bound_changes.push_back(var_bound_change); + bound_changes.push_back(entry); } } diff --git a/ortools/sat/pseudo_costs.h b/ortools/sat/pseudo_costs.h index c022d2b9ee..b5b950f4cf 100644 --- a/ortools/sat/pseudo_costs.h +++ b/ortools/sat/pseudo_costs.h @@ -20,6 +20,7 @@ #include "ortools/base/logging.h" #include "ortools/base/strong_vector.h" #include "ortools/sat/integer.h" +#include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -33,17 +34,15 @@ namespace sat { // objective bounds per unit change in the variable bounds. class PseudoCosts { public: - // Helper struct to get information relevant for pseudo costs from branching - // decisions. - struct VariableBoundChange { - IntegerVariable var = kNoIntegerVariable; - IntegerValue lower_bound_change = IntegerValue(0); - }; explicit PseudoCosts(Model* model); - // Updates the pseudo costs for the given decision. - void UpdateCost(const std::vector& bound_changes, - IntegerValue obj_bound_improvement); + // This must be called before we are about to branch. + // It will record the current objective bounds. + void BeforeTakingDecision(Literal decision); + + // Updates the pseudo costs for the given decision given to + // BeforeTakingDecision(). + void AfterTakingDecision(bool conflict = false); // Returns the variable with best reliable pseudo cost that is not fixed. IntegerVariable GetBestDecisionVar(); @@ -54,27 +53,73 @@ class PseudoCosts { return pseudo_costs_[var].CurrentAverage(); } - // Returns the number of recordings of given variable. Currently used for - // testing only. - int GetRecordings(IntegerVariable var) const { + // Visible for testing. + // Returns the number of recordings of given variable. + int GetNumRecords(IntegerVariable var) const { CHECK_LT(var, pseudo_costs_.size()); return pseudo_costs_[var].NumRecords(); } - // Returns extracted information to update pseudo costs from the given - // branching decision. + // Alternative pseudo-costs. This relies on the LP more heavily and is more + // in line with what a MIP solver would do. + double LpPseudoCost(IntegerVariable var, double down_fractionality) const; + + // Returns the pseudo cost "reliability". + int LpReliability(IntegerVariable var) const; + + // Experimental alternative pseudo cost based on the explanation for bound + // increases. + void UpdateBoolPseudoCosts(absl::Span reason, + IntegerValue objective_increase); + double BoolPseudoCost(Literal lit, double lp_value) const; + + // Visible for testing. + // Returns the bound delta associated with this decision. + struct VariableBoundChange { + IntegerVariable var = kNoIntegerVariable; + IntegerValue lower_bound_change = IntegerValue(0); + double lp_increase = 0.0; + }; std::vector GetBoundChanges(Literal decision); private: - // Reference of integer trail to access the current bounds of variables. + double CombineCosts(double down_branch, double up_branch) const; + + // Returns the current objective info. + struct ObjectiveInfo { + std::string DebugString() const; + + IntegerValue lb = kMinIntegerValue; + IntegerValue ub = kMaxIntegerValue; + double lp_bound = -std::numeric_limits::infinity(); + bool lp_at_optimal = false; + }; + ObjectiveInfo GetCurrentObjectiveInfo(); + + // Model object. const SatParameters& parameters_; IntegerTrail* integer_trail_; IntegerEncoder* encoder_; + ModelLpValues* lp_values_; + LinearProgrammingConstraintCollection* lps_; + IntegerVariable objective_var_ = kNoIntegerVariable; + // Saved info by BeforeTakingDecision(). + ObjectiveInfo saved_info_; + std::vector bound_changes_; + + // Current IntegerVariable pseudo costs. std::vector relevant_variables_; absl::StrongVector is_relevant_; absl::StrongVector scores_; absl::StrongVector pseudo_costs_; + + // This version is mainly based on the lp relaxation. + absl::StrongVector + average_unit_objective_increase_; + + // This version is based on objective increase explanation. + absl::StrongVector lit_pseudo_costs_; }; } // namespace sat