reorganize synchronization code

This commit is contained in:
Laurent Perron
2019-06-25 14:27:13 +02:00
parent b34d5ba4e5
commit fcafb1b0c7
7 changed files with 291 additions and 340 deletions

View File

@@ -86,11 +86,9 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cp_model_cc_proto",
":cp_model_loader",
":cp_model_search",
":cp_model_utils",
":integer",
":integer_search",
":model",
":sat_base",
"//ortools/base",

View File

@@ -1092,6 +1092,230 @@ std::function<SatParameters(Model*)> NewSatParameters(
namespace {
// Registers a callback that will export variables bounds fixed at level 0 of
// the search. This should not be registered to a LNS search.
void RegisterVariableBoundsLevelZeroExport(
const CpModelProto& model_proto, SharedBoundsManager* shared_bounds_manager,
Model* model) {
CHECK(shared_bounds_manager != nullptr);
int saved_trail_index = 0;
const auto broadcast_level_zero_bounds =
[&model_proto, saved_trail_index, model, shared_bounds_manager](
const std::vector<IntegerVariable>& modified_vars) mutable {
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
std::vector<int> model_variables;
std::vector<int64> new_lower_bounds;
std::vector<int64> new_upper_bounds;
absl::flat_hash_set<int> visited_variables;
// Inspect the modified IntegerVariables.
auto* integer_trail = model->Get<IntegerTrail>();
for (const IntegerVariable& var : modified_vars) {
const IntegerVariable positive_var = PositiveVariable(var);
const int model_var =
mapping->GetProtoVariableFromIntegerVariable(positive_var);
if (model_var == -1 || visited_variables.contains(model_var)) {
// TODO(user): I don't think we should see the same model_var twice
// here so maybe we don't need the visited_variables.contains()
// part.
continue;
}
visited_variables.insert(model_var);
const int64 new_lb =
integer_trail->LevelZeroLowerBound(positive_var).value();
const int64 new_ub =
integer_trail->LevelZeroUpperBound(positive_var).value();
// TODO(user): We could imagine an API based on atomic<int64>
// that could preemptively check if this new bounds are improving.
model_variables.push_back(model_var);
new_lower_bounds.push_back(new_lb);
new_upper_bounds.push_back(new_ub);
}
// Inspect the newly modified Booleans.
auto* trail = model->Get<Trail>();
for (; saved_trail_index < trail->Index(); ++saved_trail_index) {
const Literal fixed_literal = (*trail)[saved_trail_index];
const int model_var = mapping->GetProtoVariableFromBooleanVariable(
fixed_literal.Variable());
if (model_var == -1 || visited_variables.contains(model_var)) {
// If the variable is already visited, it should mean that this
// Boolean also has an IntegerVariable view, and we should already
// have set its bound correctly.
continue;
}
visited_variables.insert(model_var);
model_variables.push_back(model_var);
if (fixed_literal.IsPositive()) {
new_lower_bounds.push_back(1);
new_upper_bounds.push_back(1);
} else {
new_lower_bounds.push_back(0);
new_upper_bounds.push_back(0);
}
}
if (!model_variables.empty()) {
const WorkerInfo* const worker_info =
model->GetOrCreate<WorkerInfo>();
shared_bounds_manager->ReportPotentialNewBounds(
model_proto, worker_info->worker_id, worker_info->worker_name,
model_variables, new_lower_bounds, new_upper_bounds);
}
};
model->GetOrCreate<GenericLiteralWatcher>()
->RegisterLevelZeroModifiedVariablesCallback(broadcast_level_zero_bounds);
}
// Registers a callback to import new variables bounds stored in the
// shared_bounds_manager. These bounds are imported at level 0 of the search
// in the linear scan minimize function.
void RegisterVariableBoundsLevelZeroImport(
const CpModelProto& model_proto, SharedBoundsManager* shared_bounds_manager,
Model* model) {
CHECK(shared_bounds_manager != nullptr);
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
const WorkerInfo* const worker_info = model->GetOrCreate<WorkerInfo>();
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
const auto& import_level_zero_bounds = [&model_proto, shared_bounds_manager,
model, integer_trail, worker_info,
mapping]() {
std::vector<int> model_variables;
std::vector<int64> new_lower_bounds;
std::vector<int64> new_upper_bounds;
shared_bounds_manager->GetChangedBounds(worker_info->worker_id,
&model_variables, &new_lower_bounds,
&new_upper_bounds);
bool new_bounds_have_been_imported = false;
for (int i = 0; i < model_variables.size(); ++i) {
const int model_var = model_variables[i];
// This can happen if a boolean variables is forced to have an
// integer view in one thread, and not in another thread.
if (!mapping->IsInteger(model_var)) continue;
const IntegerVariable var = mapping->Integer(model_var);
const IntegerValue new_lb(new_lower_bounds[i]);
const IntegerValue new_ub(new_upper_bounds[i]);
const IntegerValue old_lb = integer_trail->LowerBound(var);
const IntegerValue old_ub = integer_trail->UpperBound(var);
const bool changed_lb = new_lb > old_lb;
const bool changed_ub = new_ub < old_ub;
if (!changed_lb && !changed_ub) continue;
new_bounds_have_been_imported = true;
if (VLOG_IS_ON(3)) {
const IntegerVariableProto& var_proto =
model_proto.variables(model_var);
const std::string& var_name =
var_proto.name().empty()
? absl::StrCat("anonymous_var(", model_var, ")")
: var_proto.name();
LOG(INFO) << " '" << worker_info->worker_name
<< "' imports new bounds for " << var_name << ": from ["
<< old_lb << ", " << old_ub << "] to [" << new_lb << ", "
<< new_ub << "]";
}
if (changed_lb &&
!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, new_lb),
{}, {})) {
return false;
}
if (changed_ub &&
!integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(var, new_ub), {},
{})) {
return false;
}
}
if (new_bounds_have_been_imported &&
!model->GetOrCreate<SatSolver>()->FinishPropagation()) {
return false;
}
return true;
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
import_level_zero_bounds);
}
// Registers a callback that will report improving objective best bound.
// It will be called each time new objective bound are propagated at level zero.
void RegisterObjectiveBestBoundExport(
IntegerVariable objective_var,
SharedResponseManager* shared_response_manager, Model* model) {
std::string worker_name = model->GetOrCreate<WorkerInfo>()->worker_name;
auto* integer_trail = model->Get<IntegerTrail>();
const auto broadcast_objective_lower_bound =
[worker_name, objective_var, integer_trail,
shared_response_manager](const std::vector<IntegerVariable>& unused) {
shared_response_manager->UpdateInnerObjectiveBounds(
worker_name, integer_trail->LevelZeroLowerBound(objective_var),
integer_trail->LevelZeroUpperBound(objective_var));
};
model->GetOrCreate<GenericLiteralWatcher>()
->RegisterLevelZeroModifiedVariablesCallback(
broadcast_objective_lower_bound);
}
// Registers a callback to import new objective bounds. It will be called each
// time the search main loop is back to level zero. Note that it the presence of
// assumptions, this will not happend until the set of assumptions is changed.
void RegisterObjectiveBoundsImport(
SharedResponseManager* shared_response_manager, Model* model) {
auto* solver = model->GetOrCreate<SatSolver>();
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
auto* worker_info = model->GetOrCreate<WorkerInfo>();
auto* objective = model->GetOrCreate<ObjectiveDefinition>();
const auto import_objective_bounds = [solver, integer_trail, worker_info,
objective, shared_response_manager]() {
if (solver->AssumptionLevel() != 0) return true;
bool propagate = false;
const IntegerValue external_lb =
shared_response_manager->GetInnerObjectiveLowerBound();
const IntegerValue current_lb =
integer_trail->LowerBound(objective->objective_var);
if (external_lb > current_lb) {
if (!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(
objective->objective_var, external_lb),
{}, {})) {
return false;
}
propagate = true;
}
const IntegerValue external_ub =
shared_response_manager->GetInnerObjectiveUpperBound();
const IntegerValue current_ub =
integer_trail->UpperBound(objective->objective_var);
if (external_ub < current_ub) {
if (!integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(
objective->objective_var, external_ub),
{}, {})) {
return false;
}
propagate = true;
}
if (!propagate) return true;
VLOG(2) << "'" << worker_info->worker_name
<< "' imports objective bounds: external ["
<< objective->ScaleIntegerObjective(external_lb) << ", "
<< objective->ScaleIntegerObjective(external_ub) << "], current ["
<< objective->ScaleIntegerObjective(current_lb) << ", "
<< objective->ScaleIntegerObjective(current_ub) << "]";
return solver->FinishPropagation();
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
import_objective_bounds);
}
// Loads a CpModelProto inside the given model.
// This should only be called once on a given 'Model' class.
//
@@ -1301,6 +1525,26 @@ void LoadCpModel(const CpModelProto& model_proto,
RegisterObjectiveBoundsImport(shared_response_manager, model);
}
}
// Cache the relavant data for RINS variables.
if (model->Get<SharedRINSNeighborhoodManager>() != nullptr) {
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
auto* lp_dispatcher = model->GetOrCreate<LinearProgrammingDispatcher>();
auto* rins_vars = model->GetOrCreate<RINSVariables>();
IntegerVariable size = integer_trail->NumIntegerVariables();
for (IntegerVariable positive_var(0); positive_var < size;
positive_var += 2) {
RINSVariable rins_var;
rins_var.positive_var = positive_var;
rins_var.model_var =
mapping->GetProtoVariableFromIntegerVariable(positive_var);
rins_var.lp = gtl::FindWithDefault(*lp_dispatcher, positive_var, nullptr);
if (rins_var.lp != nullptr && rins_var.model_var >= 0) {
rins_vars->vars.push_back(rins_var);
}
}
}
}
// Solves an already loaded cp_model_proto.
@@ -1335,26 +1579,6 @@ void SolveLoadedCpModel(const CpModelProto& model_proto,
model_proto, mapping->GetVariableMapping(), fixed_search, model);
}
if (model->Get<SharedRINSNeighborhoodManager>() != nullptr) {
// Cache the relavant data for RINS variables.
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
auto* lp_dispatcher = model->GetOrCreate<LinearProgrammingDispatcher>();
auto* rins_vars = model->GetOrCreate<RINSVariables>();
IntegerVariable size = integer_trail->NumIntegerVariables();
for (IntegerVariable positive_var(0); positive_var < size;
positive_var += 2) {
RINSVariable rins_var;
rins_var.positive_var = positive_var;
rins_var.model_var =
mapping->GetProtoVariableFromIntegerVariable(positive_var);
rins_var.lp = gtl::FindWithDefault(*lp_dispatcher, positive_var, nullptr);
if (rins_var.lp != nullptr && rins_var.model_var >= 0) {
rins_vars->vars.push_back(rins_var);
}
}
}
const auto solution_observer = [&model_proto, &model, &solution_info,
&shared_response_manager]() {
CpSolverResponse response;
@@ -1401,54 +1625,32 @@ void SolveLoadedCpModel(const CpModelProto& model_proto,
const SatSolver::Status status = ResetAndSolveIntegerProblem({}, model);
if (status == SatSolver::Status::FEASIBLE) {
bool hint_is_valid = true;
if (model_proto.has_objective() && parameters.optimize_with_core()) {
// We need to fix the lower bound of the objective variable.
// If linearization_level = 0, the objective_var is not linked with
// the model. We recompute its value from scratch.
int64 objective_value = 0;
for (int i = 0; i < model_proto.objective().vars_size(); ++i) {
objective_value += model_proto.objective().coeffs(i) *
model->Get(LowerBound(mapping->Integer(
model_proto.objective().vars(i))));
const int old_size = solution_info.size();
solution_info += "[hint]";
solution_observer();
solution_info.resize(old_size);
if (!model_proto.has_objective()) {
if (parameters.enumerate_all_solutions()) {
model->Add(
ExcludeCurrentSolutionWithoutIgnoredVariableAndBacktrack());
} else {
return;
}
} else {
// Restrict the objective. Note that since we call the solution observer
// above, the shared_response_manager objective should be up to date.
model->GetOrCreate<SatSolver>()->Backtrack(0);
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
if (!integer_trail->Enqueue(
IntegerLiteral::GreaterOrEqual(objective_var,
IntegerValue(objective_value)),
IntegerLiteral::LowerOrEqual(
objective_var,
shared_response_manager->GetInnerObjectiveUpperBound()),
{}, {})) {
hint_is_valid = false;
model->GetOrCreate<SatSolver>()->Backtrack(0);
}
}
if (hint_is_valid) {
const int old_size = solution_info.size();
solution_info += "[hint]";
solution_observer();
solution_info.resize(old_size);
if (!model_proto.has_objective()) {
if (parameters.enumerate_all_solutions()) {
model->Add(
ExcludeCurrentSolutionWithoutIgnoredVariableAndBacktrack());
} else {
return;
}
} else {
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
IntegerValue current_internal_objective =
integer_trail->LowerBound(objective_var);
// Restrict the objective.
model->GetOrCreate<SatSolver>()->Backtrack(0);
if (!integer_trail->Enqueue(
IntegerLiteral::LowerOrEqual(objective_var,
current_internal_objective - 1),
{}, {})) {
shared_response_manager->NotifyThatImprovingProblemIsInfeasible(
absl::StrCat(solution_info, "[hint]"));
shared_response_manager->SetStatsFromModel(model);
return;
}
shared_response_manager->NotifyThatImprovingProblemIsInfeasible(
absl::StrCat(solution_info, "[hint]"));
shared_response_manager->SetStatsFromModel(model);
return;
}
}
}
@@ -1473,10 +1675,6 @@ void SolveLoadedCpModel(const CpModelProto& model_proto,
}
} else {
// Optimization problem.
const CpObjectiveProto& obj = model_proto.objective();
VLOG(2) << obj.vars_size() << " terms in the proto objective.";
VLOG(2) << "Initial num_bool: " << model->Get<SatSolver>()->NumVariables();
if (parameters.optimize_with_core()) {
std::vector<IntegerVariable> linear_vars;
std::vector<IntegerValue> linear_coeffs;
@@ -2417,6 +2615,8 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
if (log_search) {
LOG(INFO) << absl::StrFormat("*** starting sequential search at %.2fs",
wall_timer.Get());
LOG(INFO) << "Initial num_bool: "
<< model->Get<SatSolver>()->NumVariables();
}
SolveLoadedCpModel(new_cp_model_proto, &shared_response_manager, model);
}

View File

@@ -759,22 +759,5 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding(Model* model) {
return ResetAndSolveIntegerProblem(/*assumptions=*/{}, model);
}
void LogNewSolution(const std::string& event_or_solution_count,
double time_in_seconds, double obj_best, double obj_lb,
double obj_ub, const std::string& solution_info) {
const std::string obj_next =
absl::StrFormat("next:[%.9g,%.9g]", obj_lb, obj_ub);
LOG(INFO) << absl::StrFormat("#%-5s %6.2fs best:%-5.9g %-15s %s",
event_or_solution_count, time_in_seconds,
obj_best, obj_next, solution_info);
}
void LogNewSatSolution(const std::string& event_or_solution_count,
double time_in_seconds,
const std::string& solution_info) {
LOG(INFO) << absl::StrFormat("#%-5s %6.2fs %s", event_or_solution_count,
time_in_seconds, solution_info);
}
} // namespace sat
} // namespace operations_research

View File

@@ -206,16 +206,6 @@ std::vector<std::function<LiteralIndex()>> CompleteHeuristics(
const std::vector<std::function<LiteralIndex()>>& incomplete_heuristics,
const std::function<LiteralIndex()>& completion_heuristic);
// Prints out a new optimization solution in a fixed format.
void LogNewSolution(const std::string& event_or_solution_count,
double time_in_seconds, double obj_best, double obj_lb,
double obj_ub, const std::string& solution_info);
// Prints out a new satisfiability solution in a fixed format.
void LogNewSatSolution(const std::string& event_or_solution_count,
double time_in_seconds,
const std::string& solution_info);
} // namespace sat
} // namespace operations_research

View File

@@ -20,7 +20,7 @@
#include <vector>
#if !defined(__PORTABLE_PLATFORM__)
#include "ortools/base/threadpool.h"
#include "ortools/base/threadpool.h"
#endif // __PORTABLE_PLATFORM__
namespace operations_research {

View File

@@ -16,11 +16,9 @@
#include "absl/container/flat_hash_set.h"
#include "ortools/base/stl_util.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/cp_model_loader.h"
#include "ortools/sat/cp_model_search.h"
#include "ortools/sat/cp_model_utils.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_search.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
@@ -77,6 +75,27 @@ SharedResponseManager::SharedResponseManager(bool log_updates,
wall_timer_(*wall_timer),
solutions_(/*num_solutions_to_keep=*/10) {}
namespace {
void LogNewSolution(const std::string& event_or_solution_count,
double time_in_seconds, double obj_best, double obj_lb,
double obj_ub, const std::string& solution_info) {
const std::string obj_next =
absl::StrFormat("next:[%.9g,%.9g]", obj_lb, obj_ub);
LOG(INFO) << absl::StrFormat("#%-5s %6.2fs best:%-5.9g %-15s %s",
event_or_solution_count, time_in_seconds,
obj_best, obj_next, solution_info);
}
void LogNewSatSolution(const std::string& event_or_solution_count,
double time_in_seconds,
const std::string& solution_info) {
LOG(INFO) << absl::StrFormat("#%-5s %6.2fs %s", event_or_solution_count,
time_in_seconds, solution_info);
}
} // namespace
void SharedResponseManager::UpdateInnerObjectiveBounds(
const std::string& worker_info, IntegerValue lb, IntegerValue ub) {
absl::MutexLock mutex_lock(&mutex_);
@@ -272,7 +291,7 @@ void SharedResponseManager::NewSolution(const CpSolverResponse& response,
std::string solution_info = response.solution_info();
if (model != nullptr) {
absl::StrAppend(&solution_info,
" num_bool:", model->Get<SatSolver>()->NumVariables());
" num_bool:", model->Get<Trail>()->NumVariables());
}
if (model_proto_.has_objective()) {
@@ -407,219 +426,5 @@ void SharedBoundsManager::GetChangedBounds(
}
}
void RegisterVariableBoundsLevelZeroExport(
const CpModelProto& model_proto, SharedBoundsManager* shared_bounds_manager,
Model* model) {
CHECK(shared_bounds_manager != nullptr);
int saved_trail_index = 0;
const auto broadcast_level_zero_bounds =
[&model_proto, saved_trail_index, model, shared_bounds_manager](
const std::vector<IntegerVariable>& modified_vars) mutable {
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
std::vector<int> model_variables;
std::vector<int64> new_lower_bounds;
std::vector<int64> new_upper_bounds;
absl::flat_hash_set<int> visited_variables;
// Inspect the modified IntegerVariables.
auto* integer_trail = model->Get<IntegerTrail>();
for (const IntegerVariable& var : modified_vars) {
const IntegerVariable positive_var = PositiveVariable(var);
const int model_var =
mapping->GetProtoVariableFromIntegerVariable(positive_var);
if (model_var == -1 || visited_variables.contains(model_var)) {
// TODO(user): I don't think we should see the same model_var twice
// here so maybe we don't need the visited_variables.contains()
// part.
continue;
}
visited_variables.insert(model_var);
const int64 new_lb =
integer_trail->LevelZeroLowerBound(positive_var).value();
const int64 new_ub =
integer_trail->LevelZeroUpperBound(positive_var).value();
// TODO(user): We could imagine an API based on atomic<int64>
// that could preemptively check if this new bounds are improving.
model_variables.push_back(model_var);
new_lower_bounds.push_back(new_lb);
new_upper_bounds.push_back(new_ub);
}
// Inspect the newly modified Booleans.
auto* trail = model->Get<Trail>();
for (; saved_trail_index < trail->Index(); ++saved_trail_index) {
const Literal fixed_literal = (*trail)[saved_trail_index];
const int model_var = mapping->GetProtoVariableFromBooleanVariable(
fixed_literal.Variable());
if (model_var == -1 || visited_variables.contains(model_var)) {
// If the variable is already visited, it should mean that this
// Boolean also has an IntegerVariable view, and we should already
// have set its bound correctly.
continue;
}
visited_variables.insert(model_var);
model_variables.push_back(model_var);
if (fixed_literal.IsPositive()) {
new_lower_bounds.push_back(1);
new_upper_bounds.push_back(1);
} else {
new_lower_bounds.push_back(0);
new_upper_bounds.push_back(0);
}
}
if (!model_variables.empty()) {
const WorkerInfo* const worker_info =
model->GetOrCreate<WorkerInfo>();
shared_bounds_manager->ReportPotentialNewBounds(
model_proto, worker_info->worker_id, worker_info->worker_name,
model_variables, new_lower_bounds, new_upper_bounds);
}
};
model->GetOrCreate<GenericLiteralWatcher>()
->RegisterLevelZeroModifiedVariablesCallback(broadcast_level_zero_bounds);
}
void RegisterVariableBoundsLevelZeroImport(
const CpModelProto& model_proto, SharedBoundsManager* shared_bounds_manager,
Model* model) {
CHECK(shared_bounds_manager != nullptr);
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
const WorkerInfo* const worker_info = model->GetOrCreate<WorkerInfo>();
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
const auto& import_level_zero_bounds = [&model_proto, shared_bounds_manager,
model, integer_trail, worker_info,
mapping]() {
std::vector<int> model_variables;
std::vector<int64> new_lower_bounds;
std::vector<int64> new_upper_bounds;
shared_bounds_manager->GetChangedBounds(worker_info->worker_id,
&model_variables, &new_lower_bounds,
&new_upper_bounds);
bool new_bounds_have_been_imported = false;
for (int i = 0; i < model_variables.size(); ++i) {
const int model_var = model_variables[i];
// This can happen if a boolean variables is forced to have an
// integer view in one thread, and not in another thread.
if (!mapping->IsInteger(model_var)) continue;
const IntegerVariable var = mapping->Integer(model_var);
const IntegerValue new_lb(new_lower_bounds[i]);
const IntegerValue new_ub(new_upper_bounds[i]);
const IntegerValue old_lb = integer_trail->LowerBound(var);
const IntegerValue old_ub = integer_trail->UpperBound(var);
const bool changed_lb = new_lb > old_lb;
const bool changed_ub = new_ub < old_ub;
if (!changed_lb && !changed_ub) continue;
new_bounds_have_been_imported = true;
if (VLOG_IS_ON(2)) {
const IntegerVariableProto& var_proto =
model_proto.variables(model_var);
const std::string& var_name =
var_proto.name().empty()
? absl::StrCat("anonymous_var(", model_var, ")")
: var_proto.name();
LOG(INFO) << " '" << worker_info->worker_name
<< "' imports new bounds for " << var_name << ": from ["
<< old_lb << ", " << old_ub << "] to [" << new_lb << ", "
<< new_ub << "]";
}
if (changed_lb &&
!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, new_lb),
{}, {})) {
return false;
}
if (changed_ub &&
!integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(var, new_ub), {},
{})) {
return false;
}
}
if (new_bounds_have_been_imported &&
!model->GetOrCreate<SatSolver>()->FinishPropagation()) {
return false;
}
return true;
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
import_level_zero_bounds);
}
void RegisterObjectiveBestBoundExport(
IntegerVariable objective_var,
SharedResponseManager* shared_response_manager, Model* model) {
std::string worker_name = model->GetOrCreate<WorkerInfo>()->worker_name;
auto* integer_trail = model->Get<IntegerTrail>();
const auto broadcast_objective_lower_bound =
[worker_name, objective_var, integer_trail,
shared_response_manager](const std::vector<IntegerVariable>& unused) {
shared_response_manager->UpdateInnerObjectiveBounds(
worker_name, integer_trail->LevelZeroLowerBound(objective_var),
integer_trail->LevelZeroUpperBound(objective_var));
};
model->GetOrCreate<GenericLiteralWatcher>()
->RegisterLevelZeroModifiedVariablesCallback(
broadcast_objective_lower_bound);
}
void RegisterObjectiveBoundsImport(
SharedResponseManager* shared_response_manager, Model* model) {
auto* solver = model->GetOrCreate<SatSolver>();
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
auto* worker_info = model->GetOrCreate<WorkerInfo>();
auto* objective = model->GetOrCreate<ObjectiveDefinition>();
const auto import_objective_bounds = [solver, integer_trail, worker_info,
objective, shared_response_manager]() {
if (solver->AssumptionLevel() != 0) return true;
bool propagate = false;
const IntegerValue external_lb =
shared_response_manager->GetInnerObjectiveLowerBound();
const IntegerValue current_lb =
integer_trail->LowerBound(objective->objective_var);
if (external_lb > current_lb) {
if (!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(
objective->objective_var, external_lb),
{}, {})) {
return false;
}
propagate = true;
}
const IntegerValue external_ub =
shared_response_manager->GetInnerObjectiveUpperBound();
const IntegerValue current_ub =
integer_trail->UpperBound(objective->objective_var);
if (external_ub < current_ub) {
if (!integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(
objective->objective_var, external_ub),
{}, {})) {
return false;
}
propagate = true;
}
if (!propagate) return true;
VLOG(1) << "'" << worker_info->worker_name
<< "' imports objective bounds: external ["
<< objective->ScaleIntegerObjective(external_lb) << ", "
<< objective->ScaleIntegerObjective(external_ub) << "], current ["
<< objective->ScaleIntegerObjective(current_lb) << ", "
<< objective->ScaleIntegerObjective(current_ub) << "]";
return solver->FinishPropagation();
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
import_objective_bounds);
}
} // namespace sat
} // namespace operations_research

View File

@@ -211,31 +211,6 @@ class SharedBoundsManager {
absl::Mutex mutex_;
};
// Registers a callback to import new variables bounds stored in the
// shared_bounds_manager. These bounds are imported at level 0 of the search
// in the linear scan minimize function.
void RegisterVariableBoundsLevelZeroImport(
const CpModelProto& model_proto, SharedBoundsManager* shared_bounds_manager,
Model* model);
// Registers a callback that will export variables bounds fixed at level 0 of
// the search. This should not be registered to a LNS search.
void RegisterVariableBoundsLevelZeroExport(
const CpModelProto& model_proto, SharedBoundsManager* shared_bounds_manager,
Model* model);
// Registers a callback to import new objective bounds. It will be called each
// time the search main loop is back to level zero. Note that it the presence of
// assumptions, this will not happend until the set of assumptions is changed.
void RegisterObjectiveBoundsImport(
SharedResponseManager* shared_response_manager, Model* model);
// Registers a callback that will report improving objective best bound.
// It will be called each time new objective bound are propagated at level zero.
void RegisterObjectiveBestBoundExport(
IntegerVariable objective_var,
SharedResponseManager* shared_response_manager, Model* model);
// Stores information on the worker in the parallel context.
struct WorkerInfo {
std::string worker_name;