From 4db23df64aa344037cf6c1931a057dd216f3586d Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sun, 24 Nov 2024 18:36:01 +0100 Subject: [PATCH] [CP-SAT] missing std includes; fix bug in work sharing --- ortools/sat/cp_model_solver_helpers.h | 1 + ortools/sat/cuts.h | 3 +++ ortools/sat/disjunctive.h | 2 ++ ortools/sat/precedences.h | 1 + ortools/sat/util.h | 1 + ortools/sat/work_assignment.cc | 14 +++++++------- ortools/sat/work_assignment.h | 20 +++++++++++++------- ortools/sat/work_assignment_test.cc | 2 +- 8 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index 915bd91c17..d3fe19b03c 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "absl/flags/declare.h" diff --git a/ortools/sat/cuts.h b/ortools/sat/cuts.h index 28359dc474..cc0ec7dbfb 100644 --- a/ortools/sat/cuts.h +++ b/ortools/sat/cuts.h @@ -18,9 +18,12 @@ #include #include +#include +#include #include #include #include +#include #include #include diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index efe214257e..70812a4c0a 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include "absl/types/span.h" diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 2e984f41cf..2ac790a13f 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 29a35b7ac3..9098f8c2a8 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index 0d6b89ab59..839a09ebff 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -159,6 +159,8 @@ std::optional ProtoLiteral::EncodeLiteral( return result; } +ProtoTrail::ProtoTrail() { target_phase_.reserve(kMaxPhaseSize); } + void ProtoTrail::PushLevel(const ProtoLiteral& decision, IntegerValue objective_lb, int node_id) { CHECK_GT(node_id, 0); @@ -791,8 +793,8 @@ bool SharedTreeWorker::SyncWithSharedTree() { if (ShouldReplaceSubtree()) { ++num_trees_; VLOG(2) << parameters_->name() << " acquiring tree #" << num_trees_ - << " after " << num_restarts_ - tree_assignment_restart_ - << " restarts prev depth: " << assigned_tree_.MaxLevel() + << " after " << restart_policy_->NumRestarts() << " restarts" + << " prev depth: " << assigned_tree_.MaxLevel() << " target: " << assigned_tree_lbds_.WindowAverage() << " lbd: " << restart_policy_->LbdAverageSinceReset(); if (parameters_->shared_tree_worker_enable_phase_sharing() && @@ -804,11 +806,10 @@ bool SharedTreeWorker::SyncWithSharedTree() { // workers. auto encoded = ProtoLiteral::EncodeLiteral(lit, mapping_); if (!encoded.has_value()) continue; - assigned_tree_.SetPhase(*encoded); + if (!assigned_tree_.AddPhase(*encoded)) break; } } manager_->ReplaceTree(assigned_tree_); - tree_assignment_restart_ = num_restarts_; assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset()); restart_policy_->Reset(); if (parameters_->shared_tree_worker_enable_phase_sharing()) { @@ -854,9 +855,8 @@ SatSolver::Status SharedTreeWorker::Search( return sat_solver_->UnsatStatus(); } if (heuristics_->restart_policies[heuristics_->policy_index]()) { - ++num_restarts_; - heuristics_->policy_index = - num_restarts_ % heuristics_->decision_policies.size(); + heuristics_->policy_index = restart_policy_->NumRestarts() % + heuristics_->decision_policies.size(); sat_solver_->Backtrack(0); } if (!SyncWithLocalTrail()) return sat_solver_->UnsatStatus(); diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index 6f5119b9e9..162b2ec648 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -102,6 +102,8 @@ class ProtoLiteral { // implications may be propagated. class ProtoTrail { public: + ProtoTrail(); + // Adds a new assigned level to the trail. void PushLevel(const ProtoLiteral& decision, IntegerValue objective_lb, int node_id); @@ -147,18 +149,25 @@ class ProtoTrail { const std::vector& TargetPhase() const { return target_phase_; } void ClearTargetPhase() { target_phase_.clear(); } - void SetPhase(const ProtoLiteral& lit) { - if (implication_level_.contains(lit)) return; - target_phase_.push_back(lit); + // Appends a literal to the target phase, returns false if the phase is full. + bool AddPhase(const ProtoLiteral& lit) { + if (target_phase_.size() >= kMaxPhaseSize) return false; + if (!implication_level_.contains(lit)) { + target_phase_.push_back(lit); + } + return true; } void SetTargetPhase(absl::Span phase) { ClearTargetPhase(); for (const ProtoLiteral& lit : phase) { - SetPhase(lit); + if (!AddPhase(lit)) break; } } private: + // 256 ProtoLiterals take up 4KiB + static constexpr int kMaxPhaseSize = 256; + std::vector& MutableImplications(int level) { return implications_[level - 1]; } @@ -335,14 +344,11 @@ class SharedTreeWorker { LevelZeroCallbackHelper* level_zero_callbacks_; RevIntRepository* reversible_int_repository_; - int64_t num_restarts_ = 0; int64_t num_trees_ = 0; ProtoTrail assigned_tree_; std::vector assigned_tree_literals_; std::vector> assigned_tree_implications_; - // How many restarts had happened when the current tree was assigned? - int64_t tree_assignment_restart_ = -1; // True if the last decision may split the assigned tree and has not yet been // proposed to the SharedTreeManager. diff --git a/ortools/sat/work_assignment_test.cc b/ortools/sat/work_assignment_test.cc index 6bc2e98d2d..fbe3f135f9 100644 --- a/ortools/sat/work_assignment_test.cc +++ b/ortools/sat/work_assignment_test.cc @@ -562,7 +562,7 @@ TEST(SharedTreeManagerTest, TrailSharing) { trail1.AddImplication(1, ProtoLiteral(1, 1)); trail1.AddImplication(1, ProtoLiteral(1, 3)); shared_tree_manager->SyncTree(trail1); - trail1.SetPhase(ProtoLiteral(2, 1)); + trail1.AddPhase(ProtoLiteral(2, 1)); shared_tree_manager->ReplaceTree(trail1); shared_tree_manager->ReplaceTree(trail2);