[CP-SAT] reduce memory usage of integer variables; add option to save model in binary format

This commit is contained in:
Laurent Perron
2023-02-09 16:18:46 -08:00
parent 7172a5650a
commit acda5d226d
9 changed files with 120 additions and 47 deletions

View File

@@ -16,7 +16,6 @@ understand how optimization problems can be modeled using the solver. You can
then solve a model with the functions in
[cp_model_solver.h](../sat/cp_model_solver.h).
## Parameters
* [sat_parameters.proto](../sat/sat_parameters.proto):

View File

@@ -390,10 +390,13 @@ std::function<BooleanOrIntegerLiteral()> InstrumentSearchStrategy(
if (decision.boolean_literal_index != kNoLiteralIndex) {
const Literal l = Literal(decision.boolean_literal_index);
LOG(INFO) << "Boolean decision " << l;
for (const IntegerLiteral i_lit :
model->Get<IntegerEncoder>()->GetAllIntegerLiterals(l)) {
const auto& encoder = model->Get<IntegerEncoder>();
for (const IntegerLiteral i_lit : encoder->GetIntegerLiterals(l)) {
LOG(INFO) << " - associated with " << i_lit;
}
for (const auto [var, value] : encoder->GetEqualityLiterals(l)) {
LOG(INFO) << " - associated with " << var << " == " << value;
}
} else {
LOG(INFO) << "Integer decision " << decision.integer_literal;
}

View File

@@ -115,6 +115,9 @@ ABSL_FLAG(bool, cp_model_dump_models, false,
"format to 'FLAGS_cp_model_dump_prefix'{model|presolved_model|"
"mapping_model}.pb.txt.");
ABSL_FLAG(bool, cp_model_dump_text_proto, true,
"DEBUG ONLY, dump models in text proto instead of binary proto.");
ABSL_FLAG(bool, cp_model_dump_lns, false,
"DEBUG ONLY. When set to true, solve will dump all "
"lns models proto in text format to "
@@ -3348,10 +3351,17 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
#if !defined(__PORTABLE_PLATFORM__)
// Dump initial model?
if (absl::GetFlag(FLAGS_cp_model_dump_models)) {
const std::string file =
absl::StrCat(absl::GetFlag(FLAGS_cp_model_dump_prefix), "model.pb.txt");
LOG(INFO) << "Dumping cp model proto to '" << file << "'.";
CHECK_OK(file::SetTextProto(file, model_proto, file::Defaults()));
if (absl::GetFlag(FLAGS_cp_model_dump_text_proto)) {
const std::string file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "model.pb.txt");
LOG(INFO) << "Dumping cp model text proto to '" << file << "'.";
CHECK_OK(file::SetTextProto(file, model_proto, file::Defaults()));
} else {
const std::string file =
absl::StrCat(absl::GetFlag(FLAGS_cp_model_dump_prefix), "model.bin");
LOG(INFO) << "Dumping cp model binary proto to '" << file << "'.";
CHECK_OK(file::SetBinaryProto(file, model_proto, file::Defaults()));
}
}
#endif // __PORTABLE_PLATFORM__
@@ -3812,17 +3822,32 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
#if !defined(__PORTABLE_PLATFORM__)
if (absl::GetFlag(FLAGS_cp_model_dump_models)) {
const std::string presolved_file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "presolved_model.pb.txt");
LOG(INFO) << "Dumping presolved CpModelProto to '" << presolved_file
<< "'.";
CHECK_OK(file::SetTextProto(presolved_file, new_cp_model_proto,
file::Defaults()));
if (absl::GetFlag(FLAGS_cp_model_dump_text_proto)) {
const std::string presolved_file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "presolved_model.pb.txt");
LOG(INFO) << "Dumping presolved CpModelProto to '" << presolved_file
<< "'.";
CHECK_OK(file::SetTextProto(presolved_file, new_cp_model_proto,
file::Defaults()));
const std::string mapping_file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "mapping_model.pb.txt");
LOG(INFO) << "Dumping mapping CpModelProto to '" << mapping_file << "'.";
CHECK_OK(file::SetTextProto(mapping_file, mapping_proto, file::Defaults()));
const std::string mapping_file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "mapping_model.pb.txt");
LOG(INFO) << "Dumping mapping CpModelProto to '" << mapping_file << "'.";
CHECK_OK(
file::SetTextProto(mapping_file, mapping_proto, file::Defaults()));
} else {
const std::string presolved_file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "presolved_model.bin");
LOG(INFO) << "Dumping presolved CpModelProto to '" << presolved_file
<< "'.";
CHECK_OK(file::SetBinaryProto(presolved_file, new_cp_model_proto,
file::Defaults()));
const std::string mapping_file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "mapping_model.bin");
LOG(INFO) << "Dumping mapping CpModelProto to '" << mapping_file << "'.";
CHECK_OK(
file::SetBinaryProto(mapping_file, mapping_proto, file::Defaults()));
}
// If the model is convertible to a MIP, we dump it too.
//

View File

@@ -13,7 +13,8 @@
set_property(SOURCE sat.i PROPERTY CPLUSPLUS ON)
set_property(SOURCE sat.i PROPERTY SWIG_MODULE_NAME operations_research_sat)
set_property(SOURCE sat.i PROPERTY COMPILE_DEFINITIONS ${OR_TOOLS_COMPILE_DEFINITIONS} ABSL_MUST_USE_RESULT=)
set_property(SOURCE sat.i PROPERTY COMPILE_DEFINITIONS
${OR_TOOLS_COMPILE_DEFINITIONS} ABSL_MUST_USE_RESULT)
set_property(SOURCE sat.i PROPERTY COMPILE_OPTIONS
-namespace ${DOTNET_PROJECT}.Sat
-dllimport google-ortools-native)

View File

@@ -343,6 +343,12 @@ void IntegerEncoder::AssociateToIntegerLiteral(Literal literal,
void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
IntegerVariable var,
IntegerValue value) {
// The function is symmetric and we only deal with positive variable.
if (!VariableIsPositive(var)) {
var = NegationOf(var);
value = -value;
}
// Detect literal view. Note that the same literal can be associated to more
// than one variable, and thus already have a view. We don't change it in
// this case.
@@ -393,8 +399,7 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
equality_by_var_.resize(index.value() + 1);
is_fully_encoded_.resize(index.value() + 1);
}
equality_by_var_[index].push_back(
{VariableIsPositive(var) ? value : -value, literal});
equality_by_var_[index].push_back({value, literal});
// Fix literal for constant domain.
if (value == domain.Min() && value == domain.Max()) {
@@ -429,11 +434,10 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
// Update reverse encoding.
const int new_size = 1 + literal.Index().value();
if (new_size > full_reverse_encoding_.size()) {
full_reverse_encoding_.resize(new_size);
if (new_size > reverse_equality_encoding_.size()) {
reverse_equality_encoding_.resize(new_size);
}
full_reverse_encoding_[literal.Index()].push_back(le);
full_reverse_encoding_[literal.Index()].push_back(ge);
reverse_equality_encoding_[literal.Index()].push_back({var, value});
}
// TODO(user): The hard constraints we add between associated literals seems to
@@ -447,9 +451,6 @@ void IntegerEncoder::HalfAssociateGivenLiteral(IntegerLiteral i_lit,
if (new_size > reverse_encoding_.size()) {
reverse_encoding_.resize(new_size);
}
if (new_size > full_reverse_encoding_.size()) {
full_reverse_encoding_.resize(new_size);
}
// Associate the new literal to i_lit.
if (i_lit.var >= encoding_by_var_.size()) {
@@ -467,7 +468,6 @@ void IntegerEncoder::HalfAssociateGivenLiteral(IntegerLiteral i_lit,
// TODO(user): do that for the other branch too?
reverse_encoding_[literal.Index()].push_back(i_lit);
full_reverse_encoding_[literal.Index()].push_back(i_lit);
} else {
const Literal associated(insert_result.first->second);
if (associated != literal) {

View File

@@ -240,6 +240,8 @@ inline std::ostream& operator<<(std::ostream& os,
}
using InlinedIntegerLiteralVector = absl::InlinedVector<IntegerLiteral, 2>;
using InlinedIntegerValueVector =
absl::InlinedVector<std::pair<IntegerVariable, IntegerValue>, 2>;
// Represents [coeff * variable + constant] or just a [constant].
//
@@ -504,14 +506,27 @@ class IntegerEncoder {
return reverse_encoding_[lit.Index()];
}
// Same as GetIntegerLiterals(), but in addition, if the literal was
// associated to an integer == value, then the returned list will contain both
// (integer >= value) and (integer <= value).
const InlinedIntegerLiteralVector& GetAllIntegerLiterals(Literal lit) const {
if (lit.Index() >= full_reverse_encoding_.size()) {
return empty_integer_literal_vector_;
// Returns the variable == value pairs that were associated with the given
// Literal. Note that only positive IntegerVariable appears here.
const InlinedIntegerValueVector& GetEqualityLiterals(Literal lit) const {
if (lit.Index() >= reverse_equality_encoding_.size()) {
return empty_integer_value_vector_;
}
return full_reverse_encoding_[lit.Index()];
return reverse_equality_encoding_[lit.Index()];
}
// Returns all the variables for which this literal is associated to either
// var >= value or var == value.
const std::vector<IntegerVariable>& GetAllAssociatedVariables(
Literal lit) const {
temp_associated_vars_.clear();
for (const IntegerLiteral l : GetIntegerLiterals(lit)) {
temp_associated_vars_.push_back(l.var);
}
for (const auto [var, value] : GetEqualityLiterals(lit)) {
temp_associated_vars_.push_back(var);
}
return temp_associated_vars_;
}
// This is part of a "hack" to deal with new association involving a fixed
@@ -610,8 +625,13 @@ class IntegerEncoder {
const InlinedIntegerLiteralVector empty_integer_literal_vector_;
absl::StrongVector<LiteralIndex, InlinedIntegerLiteralVector>
reverse_encoding_;
absl::StrongVector<LiteralIndex, InlinedIntegerLiteralVector>
full_reverse_encoding_;
const InlinedIntegerValueVector empty_integer_value_vector_;
absl::StrongVector<LiteralIndex, InlinedIntegerValueVector>
reverse_equality_encoding_;
// Used by GetAllAssociatedVariables().
mutable std::vector<IntegerVariable> temp_associated_vars_;
std::vector<IntegerLiteral> newly_fixed_integer_literals_;
// Store for a given LiteralIndex its IntegerVariable view or kNoLiteralIndex

View File

@@ -243,13 +243,15 @@ std::function<BooleanOrIntegerLiteral()> SequentialValueSelection(
// Boolean case. We try to decode the Boolean decision to see if it is
// associated with an integer variable.
for (const IntegerLiteral l : encoder->GetAllIntegerLiterals(
//
// TODO(user): we will likely stop at the first non-fixed variable.
for (const IntegerVariable var : encoder->GetAllAssociatedVariables(
Literal(current_decision.boolean_literal_index))) {
if (integer_trail->IsCurrentlyIgnored(l.var)) continue;
if (integer_trail->IsCurrentlyIgnored(var)) continue;
// Sequentially try the value selection heuristics.
for (const auto& value_heuristic : value_selection_heuristics) {
const IntegerLiteral decision = value_heuristic(l.var);
const IntegerLiteral decision = value_heuristic(var);
if (decision.IsValid()) return BooleanOrIntegerLiteral(decision);
}
}
@@ -270,6 +272,9 @@ bool LinearizedPartIsLarge(Model* model) {
return (num_integer_variables <= 2 * num_lp_variables);
}
// Note that all these heuristic do not depend on the variable being positive
// or negative.
//
// TODO(user): Experiment more with value selection heuristics.
std::function<BooleanOrIntegerLiteral()> IntegerValueSelectionHeuristic(
std::function<BooleanOrIntegerLiteral()> var_selection_heuristic,
@@ -604,13 +609,13 @@ std::function<BooleanOrIntegerLiteral()> RandomizeOnRestartHeuristic(
}
// Decode the decision and get the variable.
for (const IntegerLiteral l : encoder->GetAllIntegerLiterals(
for (const IntegerVariable var : encoder->GetAllAssociatedVariables(
Literal(current_decision.boolean_literal_index))) {
if (integer_trail->IsCurrentlyIgnored(l.var)) continue;
if (integer_trail->IsCurrentlyIgnored(var)) continue;
// Try the selected policy.
const IntegerLiteral new_decision =
value_selection_heuristics[val_policy_index](l.var);
value_selection_heuristics[val_policy_index](var);
if (new_decision.IsValid()) return BooleanOrIntegerLiteral(new_decision);
}

View File

@@ -13,7 +13,8 @@
set_property(SOURCE sat.i PROPERTY CPLUSPLUS ON)
set_property(SOURCE sat.i PROPERTY SWIG_MODULE_NAME main)
set_property(SOURCE sat.i PROPERTY COMPILE_DEFINITIONS ${OR_TOOLS_COMPILE_DEFINITIONS} ABSL_MUST_USE_RESULT=)
set_property(SOURCE sat.i PROPERTY COMPILE_DEFINITIONS
${OR_TOOLS_COMPILE_DEFINITIONS} ABSL_MUST_USE_RESULT)
set_property(SOURCE sat.i PROPERTY COMPILE_OPTIONS
-package ${JAVA_PACKAGE}.sat)
swig_add_library(jnisat

View File

@@ -106,9 +106,7 @@ std::vector<PseudoCosts::VariableBoundChange> PseudoCosts::GetBoundChanges(
Literal decision) {
std::vector<PseudoCosts::VariableBoundChange> bound_changes;
// NOTE: We ignore negation of equality decisions.
for (const IntegerLiteral l : encoder_->GetAllIntegerLiterals(decision)) {
if (l.var == kNoIntegerVariable) continue;
for (const IntegerLiteral l : encoder_->GetIntegerLiterals(decision)) {
if (integer_trail_->IsCurrentlyIgnored(l.var)) continue;
PseudoCosts::VariableBoundChange var_bound_change;
var_bound_change.var = l.var;
@@ -117,6 +115,27 @@ std::vector<PseudoCosts::VariableBoundChange> PseudoCosts::GetBoundChanges(
bound_changes.push_back(var_bound_change);
}
// NOTE: We ignore literal associated to var != value.
for (const auto [var, value] : encoder_->GetEqualityLiterals(decision)) {
if (integer_trail_->IsCurrentlyIgnored(var)) continue;
{
PseudoCosts::VariableBoundChange var_bound_change;
var_bound_change.var = var;
var_bound_change.lower_bound_change =
value - integer_trail_->LowerBound(var);
bound_changes.push_back(var_bound_change);
}
// Also do the negation.
{
PseudoCosts::VariableBoundChange var_bound_change;
var_bound_change.var = NegationOf(var);
var_bound_change.lower_bound_change =
(-value) - integer_trail_->LowerBound(NegationOf(var));
bound_changes.push_back(var_bound_change);
}
}
return bound_changes;
}