tentative variable bound sharing between threads
This commit is contained in:
@@ -26,8 +26,7 @@
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
typedef std::function<int64(RoutingNodeIndex, RoutingNodeIndex)>
|
||||
RoutingNodeEvaluator2;
|
||||
typedef std::function<int64(int, int)> IntPairToLong;
|
||||
|
||||
// Random seed generator.
|
||||
int32 GetSeed(bool deterministic);
|
||||
|
||||
@@ -347,6 +347,7 @@ std::string Summarize(const std::string& input) {
|
||||
struct WorkerInfo {
|
||||
std::string worker_name;
|
||||
WallTimer* global_timer = nullptr;
|
||||
int worker_id = -1;
|
||||
bool parallel_mode = false;
|
||||
};
|
||||
|
||||
@@ -1197,18 +1198,62 @@ void LogNewSolution(const std::string& event_or_solution_count,
|
||||
obj_ub, solution_info);
|
||||
}
|
||||
|
||||
void RegisterObjectiveLowerBoundWatcher(
|
||||
void RegisterVariableBoundsLevelZeroWatcher(
|
||||
const CpModelProto* model_proto,
|
||||
const std::function<void(const CpSolverResponse&)>&
|
||||
external_solution_observer,
|
||||
IntegerVariable objective_var, Model* model) {
|
||||
IntegerVariable objective_var,
|
||||
SharedBoundsManager* shared_bounds_manager, Model* model) {
|
||||
const auto broadcast_lower_bound =
|
||||
[model_proto, external_solution_observer, objective_var,
|
||||
model](const std::vector<IntegerVariable>& modified_vars) {
|
||||
model, shared_bounds_manager](const std::vector<IntegerVariable>& modified_vars) {
|
||||
auto* integer_trail = model->Get<IntegerTrail>();
|
||||
const WorkerInfo* const worker_info = model->GetOrCreate<WorkerInfo>();
|
||||
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
|
||||
|
||||
if (worker_info->parallel_mode) {
|
||||
CHECK(shared_bounds_manager != nullptr);
|
||||
std::vector<int> model_variables;
|
||||
std::vector<int64> new_lower_bounds;
|
||||
std::vector<int64> new_upper_bounds;
|
||||
absl::flat_hash_set<int> visited_variables;
|
||||
for (const IntegerVariable& var : modified_vars) {
|
||||
const IntegerVariable positive_var = PositiveVariable(var);
|
||||
const int model_var =
|
||||
mapping->GetProtoVariableFromIntegerVariable(positive_var);
|
||||
if (gtl::ContainsKey(visited_variables, model_var)) {
|
||||
continue;
|
||||
} else {
|
||||
visited_variables.insert(model_var);
|
||||
}
|
||||
if (model_var == -1) continue;
|
||||
const IntegerVariableProto &var_proto =
|
||||
model_proto->variables(model_var);
|
||||
const int64 new_lb =
|
||||
integer_trail->LevelZeroLowerBound(positive_var).value();
|
||||
const int64 new_ub =
|
||||
integer_trail->LevelZeroUpperBound(positive_var).value();
|
||||
model_variables.push_back(model_var);
|
||||
new_lower_bounds.push_back(new_lb);
|
||||
new_upper_bounds.push_back(new_ub);
|
||||
if (!var_proto.name().empty()) {
|
||||
VLOG(3) << worker_info->worker_name << " write "
|
||||
<< var_proto.name() << "(" << model_var
|
||||
<< ")[" << new_lb << ", " << new_ub << "]";
|
||||
} else {
|
||||
VLOG(3) << worker_info->worker_name << " write anonymous_var("
|
||||
<< model_var << ")[" << new_lb << ", " << new_ub << "]";
|
||||
}
|
||||
}
|
||||
if (!model_variables.empty()) {
|
||||
shared_bounds_manager->ReportPotentialNewBounds(
|
||||
model_variables, new_lower_bounds, new_upper_bounds,
|
||||
worker_info->worker_id, worker_info->worker_name);
|
||||
}
|
||||
}
|
||||
|
||||
const ObjectiveSynchronizationHelper* const helper =
|
||||
model->GetOrCreate<ObjectiveSynchronizationHelper>();
|
||||
const WorkerInfo* const worker_info = model->GetOrCreate<WorkerInfo>();
|
||||
const CpObjectiveProto& obj = model_proto->objective();
|
||||
const double new_best_bound = ScaleObjectiveValue(
|
||||
obj, integer_trail->LevelZeroLowerBound(objective_var).value());
|
||||
@@ -1216,7 +1261,7 @@ void RegisterObjectiveLowerBoundWatcher(
|
||||
const double current_objective_value =
|
||||
helper->get_external_best_objective();
|
||||
|
||||
// TODO(user): Unit test this lambda.
|
||||
// TODO(lperron): Unit test this lambda.
|
||||
if ((helper->scaling_factor >= 0 && // Unset -> = 0.0 -> minimize.
|
||||
new_best_bound > current_best_bound) ||
|
||||
(helper->scaling_factor < 0 &&
|
||||
@@ -1243,6 +1288,43 @@ void RegisterObjectiveLowerBoundWatcher(
|
||||
|
||||
model->GetOrCreate<GenericLiteralWatcher>()
|
||||
->RegisterLevelZeroModifiedVariablesCallback(broadcast_lower_bound);
|
||||
|
||||
if (shared_bounds_manager != nullptr) {
|
||||
const auto& import_lower_bounds = [model_proto, shared_bounds_manager, model]() {
|
||||
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
|
||||
const WorkerInfo* const worker_info = model->GetOrCreate<WorkerInfo>();
|
||||
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
|
||||
CHECK(worker_info->parallel_mode);
|
||||
std::vector<int> model_variables;
|
||||
std::vector<int64> new_lower_bounds;
|
||||
std::vector<int64> new_upper_bounds;
|
||||
shared_bounds_manager->GetChangedBounds(
|
||||
worker_info->worker_id, &model_variables, &new_lower_bounds,
|
||||
&new_upper_bounds);
|
||||
for (int i = 0; i < model_variables.size(); ++i) {
|
||||
// This can happen if a boolean variables is force to have an
|
||||
// integer view in one thread, and not in another thread.
|
||||
if (!mapping->IsInteger(model_variables[i])) continue;
|
||||
const IntegerVariable var = mapping->Integer(model_variables[i]);
|
||||
const IntegerValue new_lb(new_lower_bounds[i]);
|
||||
const IntegerValue new_ub(new_upper_bounds[i]);
|
||||
VLOG(3) << worker_info->worker_name << " read "
|
||||
<< model_proto->variables(model_variables[i]).name() << "["
|
||||
<< new_lb << ", " << new_ub << "]";
|
||||
if (!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, new_lb),
|
||||
{}, {})) {
|
||||
return false;
|
||||
}
|
||||
if (!integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(var, new_ub),
|
||||
{}, {})) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
model->GetOrCreate<GenericLiteralWatcher>()
|
||||
->RegisterLevelZeroImportExternalBoundsCallback(import_lower_bounds);
|
||||
}
|
||||
}
|
||||
// Because we also use this function for postsolve, we call it with
|
||||
// is_real_solve set to true and avoid doing non-useful work in this case.
|
||||
@@ -1250,7 +1332,8 @@ CpSolverResponse SolveCpModelInternal(
|
||||
const CpModelProto& model_proto, bool is_real_solve,
|
||||
const std::function<void(const CpSolverResponse&)>&
|
||||
external_solution_observer,
|
||||
bool watch_objective_lower_bound, Model* model) {
|
||||
bool watch_objective_lower_bound,
|
||||
SharedBoundsManager* shared_bounds_manager, Model* model) {
|
||||
// Timing.
|
||||
WallTimer wall_timer;
|
||||
UserTimer user_timer;
|
||||
@@ -1476,6 +1559,7 @@ CpSolverResponse SolveCpModelInternal(
|
||||
// Detect sequential mode, register callbacks in that case.
|
||||
if (model->Get<WorkerInfo>() == nullptr) {
|
||||
model->GetOrCreate<WorkerInfo>()->global_timer = &wall_timer;
|
||||
model->GetOrCreate<WorkerInfo>()->worker_id = 0;
|
||||
auto* integer_trail = model->Get<IntegerTrail>();
|
||||
const CpObjectiveProto& obj = model_proto.objective();
|
||||
const auto get_objective_value = [&response, integer_trail, &obj,
|
||||
@@ -1498,8 +1582,9 @@ CpSolverResponse SolveCpModelInternal(
|
||||
SetObjectiveSynchronizationFunctions(get_objective_value,
|
||||
get_objective_best_bound, model);
|
||||
}
|
||||
RegisterObjectiveLowerBoundWatcher(&model_proto, external_solution_observer,
|
||||
objective_var, model);
|
||||
RegisterVariableBoundsLevelZeroWatcher(
|
||||
&model_proto, external_solution_observer, objective_var,
|
||||
shared_bounds_manager, model);
|
||||
}
|
||||
|
||||
// Load solution hint.
|
||||
@@ -1666,9 +1751,10 @@ void PostsolveResponse(const CpModelProto& model_proto,
|
||||
params.set_linearization_level(0);
|
||||
postsolve_model.Add(operations_research::sat::NewSatParameters(params));
|
||||
}
|
||||
const CpSolverResponse postsolve_response = SolveCpModelInternal(
|
||||
mapping_proto, false, [](const CpSolverResponse&) {},
|
||||
/*watch_objective_lower_bound=*/false, &postsolve_model);
|
||||
const CpSolverResponse postsolve_response =
|
||||
SolveCpModelInternal(mapping_proto, false, [](const CpSolverResponse&) {},
|
||||
/*watch_objective_lower_bound=*/false,
|
||||
/*shared_bounds_manager=*/nullptr, &postsolve_model);
|
||||
CHECK_EQ(postsolve_response.status(), CpSolverStatus::FEASIBLE);
|
||||
|
||||
// We only copy the solution from the postsolve_response to the response.
|
||||
@@ -1872,7 +1958,8 @@ CpSolverResponse SolveCpModelWithLNS(
|
||||
} else {
|
||||
response =
|
||||
SolveCpModelInternal(model_proto, /*is_real_solve=*/true, observer,
|
||||
/*watch_objective_lower_bound=*/false, model);
|
||||
/*watch_objective_lower_bound=*/false,
|
||||
/*shared_bounds_manager=*/nullptr, model);
|
||||
}
|
||||
if (response.status() != CpSolverStatus::FEASIBLE) {
|
||||
return response;
|
||||
@@ -1960,7 +2047,8 @@ CpSolverResponse SolveCpModelWithLNS(
|
||||
&postsolve_mapping);
|
||||
local_response = SolveCpModelInternal(
|
||||
local_problem, true, [](const CpSolverResponse& response) {},
|
||||
/*watch_objective_lower_bound=*/false, &local_model);
|
||||
/*watch_objective_lower_bound=*/false,
|
||||
/*shared_bounds_manager=*/nullptr, &local_model);
|
||||
PostsolveResponse(model_proto, mapping_proto, postsolve_mapping,
|
||||
&local_response);
|
||||
}
|
||||
@@ -2078,14 +2166,14 @@ CpSolverResponse SolveCpModelParallel(
|
||||
const SatParameters local_params = DiversifySearchParameters(
|
||||
params, model_proto, worker_id, &worker_name);
|
||||
pool.Schedule([&model_proto, stopped, local_params, &best_response,
|
||||
&mutex, worker_name]() {
|
||||
&mutex, worker_name, worker_id]() {
|
||||
Model local_model;
|
||||
local_model.Add(NewSatParameters(local_params));
|
||||
local_model.GetOrCreate<TimeLimit>()->RegisterExternalBooleanAsLimit(
|
||||
stopped);
|
||||
const CpSolverResponse local_response = SolveCpModelInternal(
|
||||
model_proto, true, [](const CpSolverResponse& response) {},
|
||||
/*watch_objective_lower_bound=*/false, &local_model);
|
||||
/*watch_objective_lower_bound=*/false, /*shared_bounds_manager=*/nullptr, &local_model);
|
||||
|
||||
absl::MutexLock lock(&mutex);
|
||||
if (best_response.status() == CpSolverStatus::UNKNOWN) {
|
||||
@@ -2116,7 +2204,10 @@ CpSolverResponse SolveCpModelParallel(
|
||||
return best_response;
|
||||
};
|
||||
|
||||
SharedBoundsManager shared_bounds_manager(num_search_workers,
|
||||
model_proto.variables_size());
|
||||
{
|
||||
|
||||
ThreadPool pool("Parallel_search", num_search_workers);
|
||||
pool.StartWorkers();
|
||||
|
||||
@@ -2153,7 +2244,7 @@ CpSolverResponse SolveCpModelParallel(
|
||||
objective_synchronization, objective_bound_synchronization,
|
||||
stopped, local_params, worker_id, &mutex, &best_response,
|
||||
num_search_workers, random_seed, global_timer,
|
||||
&first_solution_found_or_search_finished, maximize,
|
||||
&first_solution_found_or_search_finished, & shared_bounds_manager, maximize,
|
||||
worker_name]() {
|
||||
Model local_model;
|
||||
local_model.Add(NewSatParameters(local_params));
|
||||
@@ -2165,6 +2256,7 @@ CpSolverResponse SolveCpModelParallel(
|
||||
worker_info->worker_name = worker_name;
|
||||
worker_info->global_timer = global_timer;
|
||||
worker_info->parallel_mode = true;
|
||||
worker_info->worker_id = worker_id;
|
||||
|
||||
SetSynchronizationFunction(std::move(solution_synchronization),
|
||||
&local_model);
|
||||
@@ -2183,7 +2275,7 @@ CpSolverResponse SolveCpModelParallel(
|
||||
} else {
|
||||
thread_response = SolveCpModelInternal(
|
||||
model_proto, true, solution_observer,
|
||||
/*watch_objective_lower_bound=*/true, &local_model);
|
||||
/*watch_objective_lower_bound=*/true, &shared_bounds_manager, &local_model);
|
||||
}
|
||||
|
||||
// Process final solution. Decide which worker has the 'best'
|
||||
@@ -2372,7 +2464,8 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
|
||||
} else {
|
||||
response = SolveCpModelInternal(
|
||||
new_model, /*is_real_solve=*/true, observer_function,
|
||||
/*watch_objective_lower_bound=*/true, model);
|
||||
/*watch_objective_lower_bound=*/true, /*shared_bounds_manager=*/nullptr,
|
||||
model);
|
||||
}
|
||||
|
||||
postprocess_solution(&response);
|
||||
|
||||
@@ -20,9 +20,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "ortools/base/integral_types.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/sat/cp_model.pb.h"
|
||||
#include "ortools/util/bitset.h"
|
||||
#include "ortools/util/sorted_interval_list.h"
|
||||
|
||||
namespace operations_research {
|
||||
@@ -117,6 +119,92 @@ std::vector<int64> AllValuesInDomain(const ProtoWithDomain& proto) {
|
||||
return result;
|
||||
}
|
||||
|
||||
class SharedBoundsManager {
|
||||
public:
|
||||
SharedBoundsManager(int num_workers, int num_variables)
|
||||
: num_workers_(num_workers),
|
||||
num_variables_(num_variables),
|
||||
changed_variables_per_workers_(num_workers),
|
||||
lower_bounds_(num_variables, kint64min),
|
||||
upper_bounds_(num_variables, kint64max) {
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
changed_variables_per_workers_[i].ClearAndResize(num_variables_);
|
||||
}
|
||||
}
|
||||
|
||||
void ReportPotentialNewBounds(const std::vector<int>& variables,
|
||||
const std::vector<int64>& new_lower_bounds,
|
||||
const std::vector<int64>& new_upper_bounds,
|
||||
int worker_id, const std::string& worker_name) {
|
||||
CHECK_EQ(variables.size(), new_lower_bounds.size());
|
||||
CHECK_EQ(variables.size(), new_upper_bounds.size());
|
||||
{
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
int modified_domains = 0;
|
||||
int fixed_domains = 0;
|
||||
for (int i = 0; i < variables.size(); ++i) {
|
||||
const int var = variables[i];
|
||||
if (var >= num_variables_) continue;
|
||||
const int64 new_lb = new_lower_bounds[i];
|
||||
const int64 new_ub = new_upper_bounds[i];
|
||||
CHECK_GE(var, 0);
|
||||
bool changed = false;
|
||||
if (lower_bounds_[var] < new_lb) {
|
||||
changed = true;
|
||||
lower_bounds_[var] = new_lb;
|
||||
}
|
||||
if (upper_bounds_[var] > new_ub) {
|
||||
changed = true;
|
||||
upper_bounds_[var] = new_ub;
|
||||
}
|
||||
if (changed) {
|
||||
if (lower_bounds_[var] == upper_bounds_[var]) {
|
||||
fixed_domains++;
|
||||
} else {
|
||||
modified_domains++;
|
||||
}
|
||||
for (int j = 0; j < num_workers_; ++j) {
|
||||
if (worker_id == j) continue;
|
||||
changed_variables_per_workers_[j].Set(var);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (fixed_domains > 0 || modified_domains > 0) {
|
||||
VLOG(1) << "Worker " << worker_name
|
||||
<< ": fixed domains=" << fixed_domains
|
||||
<< ", modified domains=" << modified_domains << " out of "
|
||||
<< variables.size() << " events";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GetChangedBounds(int worker_id, std::vector<int>* variables,
|
||||
std::vector<int64>* new_lower_bounds,
|
||||
std::vector<int64>* new_upper_bounds) {
|
||||
variables->clear();
|
||||
new_lower_bounds->clear();
|
||||
new_upper_bounds->clear();
|
||||
{
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
for (const int var : changed_variables_per_workers_[worker_id]
|
||||
.PositionsSetAtLeastOnce()) {
|
||||
variables->push_back(var);
|
||||
new_lower_bounds->push_back(lower_bounds_[var]);
|
||||
new_upper_bounds->push_back(upper_bounds_[var]);
|
||||
}
|
||||
changed_variables_per_workers_[worker_id].ClearAll();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const int num_workers_;
|
||||
const int num_variables_;
|
||||
std::vector<SparseBitset<int64>> changed_variables_per_workers_;
|
||||
std::vector<int64> lower_bounds_;
|
||||
std::vector<int64> upper_bounds_;
|
||||
absl::Mutex mutex_;
|
||||
};
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
|
||||
|
||||
@@ -1395,8 +1395,7 @@ void GenericLiteralWatcher::UpdateCallingNeeds(Trail* trail) {
|
||||
}
|
||||
|
||||
if (trail->CurrentDecisionLevel() == 0 &&
|
||||
level_zero_modified_variable_callback_ != nullptr &&
|
||||
!modified_vars_.PositionsSetAtLeastOnce().empty()) {
|
||||
level_zero_modified_variable_callback_ != nullptr) {
|
||||
level_zero_modified_variable_callback_(
|
||||
modified_vars_.PositionsSetAtLeastOnce());
|
||||
}
|
||||
@@ -1408,6 +1407,13 @@ bool GenericLiteralWatcher::Propagate(Trail* trail) {
|
||||
const int level = trail->CurrentDecisionLevel();
|
||||
UpdateCallingNeeds(trail);
|
||||
|
||||
// Checks for external bounds. Usually in the multi-thread context.
|
||||
if (trail->CurrentDecisionLevel() == 0 &&
|
||||
level_zero_import_external_bounds_callback_ != nullptr &&
|
||||
!level_zero_import_external_bounds_callback_()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Note that the priority may be set to -1 inside the loop in order to restart
|
||||
// at zero.
|
||||
int test_limit = 0;
|
||||
|
||||
@@ -971,6 +971,14 @@ class GenericLiteralWatcher : public SatPropagator {
|
||||
level_zero_modified_variable_callback_ = cb;
|
||||
}
|
||||
|
||||
// Sets a callbacks that will be called during the Propagate() method at level 0.
|
||||
//
|
||||
// THis is used to check for external bounds in a parallel context.
|
||||
void RegisterLevelZeroImportExternalBoundsCallback(
|
||||
const std::function<bool()>& cb) {
|
||||
level_zero_import_external_bounds_callback_ = cb;
|
||||
}
|
||||
|
||||
private:
|
||||
// Updates queue_ and in_queue_ with the propagator ids that need to be
|
||||
// called.
|
||||
@@ -1005,6 +1013,7 @@ class GenericLiteralWatcher : public SatPropagator {
|
||||
|
||||
std::function<void(const std::vector<IntegerVariable>&)>
|
||||
level_zero_modified_variable_callback_ = nullptr;
|
||||
std::function<bool()> level_zero_import_external_bounds_callback_ = nullptr;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(GenericLiteralWatcher);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user