[CP-SAT] lots of bugfixes, improvement to primary variables heuristics; cleanup scheduling cuts code

This commit is contained in:
Laurent Perron
2025-05-14 13:33:20 +02:00
parent edb5359fd9
commit 0a5e8db6af
25 changed files with 1088 additions and 746 deletions

View File

@@ -726,6 +726,7 @@ cc_library(
":util",
":work_assignment",
"//ortools/base",
"//ortools/base:file",
"//ortools/base:status_macros",
"//ortools/base:strong_vector",
"//ortools/base:threadpool",
@@ -1809,7 +1810,6 @@ cc_library(
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/log:vlog_is_on",
"@abseil-cpp//absl/meta:type_traits",
"@abseil-cpp//absl/random:distributions",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/types:span",
@@ -1935,6 +1935,7 @@ cc_library(
":linear_constraint",
":model",
":no_overlap_2d_helper",
":precedences",
":sat_base",
":sat_solver",
":scheduling_helpers",
@@ -1942,7 +1943,6 @@ cc_library(
"//ortools/util:strong_integers",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/meta:type_traits",
"@abseil-cpp//absl/types:span",
],
)
@@ -1976,13 +1976,11 @@ cc_library(
":sat_base",
":sat_solver",
"//ortools/base",
"//ortools/base:strong_vector",
"//ortools/util:bitset",
"//ortools/util:sort",
"//ortools/util:strong_integers",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/meta:type_traits",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/types:span",
],
@@ -2891,10 +2889,12 @@ cc_library(
":cuts",
":integer",
":integer_base",
":intervals",
":linear_constraint",
":linear_constraint_manager",
":model",
":sat_base",
":sat_solver",
":scheduling_helpers",
":util",
"//ortools/base",
@@ -2906,7 +2906,6 @@ cc_library(
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/container:btree",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/types:span",
@@ -2929,6 +2928,7 @@ cc_test(
":model",
":sat_base",
":scheduling_cuts",
":scheduling_helpers",
"//ortools/base:gmock_main",
"//ortools/base:strong_vector",
"//ortools/util:strong_integers",
@@ -3542,13 +3542,11 @@ cc_library(
":synchronization",
":timetable",
":util",
"//ortools/base:stl_util",
"//ortools/util:bitset",
"//ortools/util:saturated_arithmetic",
"//ortools/util:strong_integers",
"//ortools/util:time_limit",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/container:inlined_vector",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/log:vlog_is_on",

View File

@@ -27,15 +27,10 @@
#include <string>
#include <string_view>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>
#include "ortools/base/logging.h"
#include "ortools/base/timer.h"
#if !defined(__PORTABLE_PLATFORM__)
#include "ortools/base/helpers.h"
#include "ortools/base/options.h"
#endif // __PORTABLE_PLATFORM__
#include "absl/base/thread_annotations.h"
#include "absl/container/btree_map.h"
#include "absl/container/btree_set.h"
@@ -54,6 +49,10 @@
#include "absl/types/span.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/text_format.h"
#include "ortools/base/helpers.h"
#include "ortools/base/logging.h"
#include "ortools/base/options.h"
#include "ortools/base/timer.h"
#include "ortools/port/proto_utils.h"
#include "ortools/sat/combine_solutions.h"
#include "ortools/sat/cp_model.pb.h"
@@ -92,9 +91,9 @@
#include "ortools/sat/work_assignment.h"
#include "ortools/util/logging.h"
#include "ortools/util/random_engine.h"
#if !defined(__PORTABLE_PLATFORM__)
#if !defined(__EMBEDDED_PLATFORM__)
#include "ortools/util/sigint.h"
#endif // __PORTABLE_PLATFORM__
#endif // __EMBEDDED_PLATFORM__
#include "ortools/base/version.h"
#include "ortools/util/sorted_interval_list.h"
#include "ortools/util/time_limit.h"
@@ -1208,7 +1207,7 @@ class FullProblemSolver : public SubSolver {
bool previous_task_is_completed_ ABSL_GUARDED_BY(mutex_) = true;
};
#if !defined(__PORTABLE_PLATFORM__)
#if !defined(__EMBEDDED_PLATFORM__)
class FeasibilityPumpSolver : public SubSolver {
public:
@@ -1398,7 +1397,7 @@ class LnsSolver : public SubSolver {
break;
}
const std::string_view search_info =
absl::StripPrefix(std::string_view(local_params.name()), "lns_");
absl::StripPrefix(absl::string_view(local_params.name()), "lns_");
local_params.set_max_deterministic_time(data.deterministic_limit);
std::string source_info =
@@ -2218,7 +2217,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
LaunchSubsolvers(params, shared, subsolvers, name_filter.AllIgnored());
}
#endif // __PORTABLE_PLATFORM__
#endif // !defined(__EMBEDDED_PLATFORM__)
// If the option use_sat_inprocessing is true, then before post-solving a
// solution, we need to make sure we add any new clause required for postsolving
@@ -2263,18 +2262,27 @@ std::function<void(Model*)> NewBestBoundCallback(
};
}
#if !defined(__PORTABLE_PLATFORM__)
namespace {
template <typename T>
void ParseFromStringOrDie(absl::string_view str, T* proto) {
if constexpr (std::is_base_of_v<google::protobuf::Message, T>) {
CHECK(google::protobuf::TextFormat::ParseFromString(str, proto)) << str;
} else {
LOG(FATAL) << "Calling NewSatParameters() with a textual proto is not "
"supported when using Lite Protobuf.";
}
}
} // namespace
// TODO(user): Support it on android.
std::function<SatParameters(Model*)> NewSatParameters(
const std::string& params) {
sat::SatParameters parameters;
if (!params.empty()) {
CHECK(google::protobuf::TextFormat::ParseFromString(params, &parameters))
<< params;
ParseFromStringOrDie<SatParameters>(params, &parameters);
}
return NewSatParameters(parameters);
}
#endif // __PORTABLE_PLATFORM__
std::function<SatParameters(Model*)> NewSatParameters(
const sat::SatParameters& parameters) {
@@ -2337,15 +2345,15 @@ void RegisterSearchStatisticCallback(Model* global_model) {
}
void MergeParamsWithFlagsAndDefaults(SatParameters* params) {
#if !defined(__PORTABLE_PLATFORM__)
// Override parameters?
if (!absl::GetFlag(FLAGS_cp_model_params).empty()) {
SatParameters flag_params;
CHECK(google::protobuf::TextFormat::ParseFromString(
absl::GetFlag(FLAGS_cp_model_params), &flag_params));
params->MergeFrom(flag_params);
if constexpr (std::is_base_of_v<google::protobuf::Message, SatParameters>) {
// Override parameters?
if (!absl::GetFlag(FLAGS_cp_model_params).empty()) {
SatParameters flag_params;
ParseFromStringOrDie<SatParameters>(absl::GetFlag(FLAGS_cp_model_params),
&flag_params);
params->MergeFrom(flag_params);
}
}
#endif // __PORTABLE_PLATFORM__
}
} // namespace
@@ -2356,19 +2364,19 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
wall_timer->Start();
user_timer->Start();
#if !defined(__PORTABLE_PLATFORM__)
// Dump initial model?
if (absl::GetFlag(FLAGS_cp_model_dump_models)) {
DumpModelProto(model_proto, "model");
}
if (absl::GetFlag(FLAGS_cp_model_export_model)) {
if (model_proto.name().empty()) {
DumpModelProto(model_proto, "unnamed_model");
} else {
DumpModelProto(model_proto, model_proto.name());
if constexpr (std::is_base_of_v<google::protobuf::Message, CpModelProto>) {
// Dump initial model?
if (absl::GetFlag(FLAGS_cp_model_dump_models)) {
DumpModelProto(model_proto, "model");
}
if (absl::GetFlag(FLAGS_cp_model_export_model)) {
if (model_proto.name().empty()) {
DumpModelProto(model_proto, "unnamed_model");
} else {
DumpModelProto(model_proto, model_proto.name());
}
}
}
#endif // __PORTABLE_PLATFORM__
MergeParamsWithFlagsAndDefaults(model->GetOrCreate<SatParameters>());
const SatParameters& params = *model->GetOrCreate<SatParameters>();
@@ -2389,20 +2397,21 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
absl::GetFlag(FLAGS_cp_model_dump_prefix));
RegisterSearchStatisticCallback(model);
#if !defined(__PORTABLE_PLATFORM__)
// Note that the postprocessors are executed in reverse order, so this
// will always dump the response just before it is returned since it is
// the first one we register.
if (absl::GetFlag(FLAGS_cp_model_dump_response)) {
shared_response_manager->AddFinalResponsePostprocessor(
[](CpSolverResponse* response) {
const std::string file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "response.pb.txt");
LOG(INFO) << "Dumping response proto to '" << file << "'.";
CHECK(WriteModelProtoToFile(*response, file));
});
if constexpr (std::is_base_of_v<google::protobuf::Message,
CpSolverResponse>) {
// Note that the postprocessors are executed in reverse order, so this
// will always dump the response just before it is returned since it is
// the first one we register.
if (absl::GetFlag(FLAGS_cp_model_dump_response)) {
shared_response_manager->AddFinalResponsePostprocessor(
[](CpSolverResponse* response) {
const std::string file = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "response.pb.txt");
LOG(INFO) << "Dumping response proto to '" << file << "'.";
CHECK(WriteModelProtoToFile(*response, file));
});
}
}
#endif // __PORTABLE_PLATFORM__
// Always display the final response stats if requested.
// This also copy the logs to the response if requested.
@@ -2456,13 +2465,13 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
// Initialize the time limit from the parameters.
model->GetOrCreate<TimeLimit>()->ResetLimitFromParameters(params);
#if !defined(__PORTABLE_PLATFORM__)
#if !defined(__EMBEDDED_PLATFORM__)
// Register SIGINT handler if requested by the parameters.
if (params.catch_sigint_signal()) {
model->GetOrCreate<SigintHandler>()->Register(
[shared_time_limit]() { shared_time_limit->Stop(); });
}
#endif // __PORTABLE_PLATFORM__
#endif // __EMBEDDED_PLATFORM__
SOLVER_LOG(logger, "");
SOLVER_LOG(logger, "Starting ", CpSatSolverVersion());
@@ -2868,31 +2877,33 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
});
}
#if !defined(__PORTABLE_PLATFORM__)
if (absl::GetFlag(FLAGS_cp_model_dump_models)) {
DumpModelProto(*new_cp_model_proto, "presolved_model");
DumpModelProto(*mapping_proto, "mapping_model");
if constexpr (std::is_base_of_v<google::protobuf::Message, CpModelProto> &&
std::is_base_of_v<google::protobuf::Message, MPModelProto>) {
if (absl::GetFlag(FLAGS_cp_model_dump_models)) {
DumpModelProto(*new_cp_model_proto, "presolved_model");
DumpModelProto(*mapping_proto, "mapping_model");
// If the model is convertible to a MIP, we dump it too.
//
// TODO(user): We could try to dump our linear relaxation too.
MPModelProto mip_model;
if (ConvertCpModelProtoToMPModelProto(*new_cp_model_proto, &mip_model)) {
DumpModelProto(mip_model, "presolved_mp_model");
}
// If the model is convertible to a MIP, we dump it too.
//
// TODO(user): We could try to dump our linear relaxation too.
MPModelProto mip_model;
if (ConvertCpModelProtoToMPModelProto(*new_cp_model_proto, &mip_model)) {
DumpModelProto(mip_model, "presolved_mp_model");
}
// If the model is convertible to a pure SAT one, we dump it too.
std::string cnf_string;
if (ConvertCpModelProtoToCnf(*new_cp_model_proto, &cnf_string)) {
const std::string filename = absl::StrCat(
absl::GetFlag(FLAGS_cp_model_dump_prefix), "presolved_cnf_model.cnf");
LOG(INFO) << "Dumping cnf model to '" << filename << "'.";
const absl::Status status =
file::SetContents(filename, cnf_string, file::Defaults());
if (!status.ok()) LOG(ERROR) << status;
// If the model is convertible to a pure SAT one, we dump it too.
std::string cnf_string;
if (ConvertCpModelProtoToCnf(*new_cp_model_proto, &cnf_string)) {
const std::string filename =
absl::StrCat(absl::GetFlag(FLAGS_cp_model_dump_prefix),
"presolved_cnf_model.cnf");
LOG(INFO) << "Dumping cnf model to '" << filename << "'.";
const absl::Status status =
file::SetContents(filename, cnf_string, file::Defaults());
if (!status.ok()) LOG(ERROR) << status;
}
}
}
#endif // __PORTABLE_PLATFORM__
if (params.stop_after_presolve() || shared_time_limit->LimitReached()) {
int64_t num_terms = 0;
@@ -2955,15 +2966,15 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
LoadDebugSolution(*new_cp_model_proto, model);
if (!model->GetOrCreate<TimeLimit>()->LimitReached()) {
#if defined(__PORTABLE_PLATFORM__)
#if defined(__EMBEDDED_PLATFORM__)
if (/* DISABLES CODE */ (false)) {
// We ignore the multithreading parameter in this case.
#else // __PORTABLE_PLATFORM__
#else // __EMBEDDED_PLATFORM__
if (params.num_workers() > 1 || params.interleave_search() ||
!params.subsolvers().empty() || !params.filter_subsolvers().empty() ||
params.use_ls_only()) {
SolveCpModelParallel(&shared, model);
#endif // __PORTABLE_PLATFORM__
#endif // __EMBEDDED_PLATFORM__
} else {
shared_response_manager->SetUpdateGapIntegralOnEachChange(true);
@@ -2991,14 +3002,12 @@ CpSolverResponse SolveWithParameters(const CpModelProto& model_proto,
return SolveCpModel(model_proto, &model);
}
#if !defined(__PORTABLE_PLATFORM__)
CpSolverResponse SolveWithParameters(const CpModelProto& model_proto,
const std::string& params) {
Model model;
model.Add(NewSatParameters(params));
return SolveCpModel(model_proto, &model);
}
#endif // !__PORTABLE_PLATFORM__
} // namespace sat
} // namespace operations_research

View File

@@ -123,10 +123,8 @@ std::function<void(Model*)> NewBestBoundCallback(
\endcode
* before calling \c SolveCpModel().
*/
#if !defined(__PORTABLE_PLATFORM__)
std::function<SatParameters(Model*)> NewSatParameters(
const std::string& params);
#endif // !__PORTABLE_PLATFORM__
std::function<SatParameters(Model*)> NewSatParameters(
const SatParameters& parameters);

View File

@@ -1070,6 +1070,7 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto,
auto* encoder = model->GetOrCreate<IntegerEncoder>();
auto* mapping = model->GetOrCreate<CpModelMapping>();
auto* repository = model->GetOrCreate<BinaryRelationRepository>();
auto* relations_maps = model->GetOrCreate<BinaryRelationsMaps>();
for (const ConstraintProto& ct : model_proto.constraints()) {
// Load conditional precedences and always true binary relations.
@@ -1135,6 +1136,13 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto,
if (vars.size() == 2) {
repository->Add(Literal(kNoLiteralIndex), {vars[0], coeffs[0]},
{vars[1], coeffs[1]}, rhs_min, rhs_max);
LinearExpression2 expr;
expr.vars[0] = vars[0];
expr.vars[1] = vars[1];
expr.coeffs[0] = coeffs[0];
expr.coeffs[1] = coeffs[1];
relations_maps->AddRelationBounds(expr, rhs_min, rhs_max);
}
} else {
const Literal lit = mapping->Literal(ct.enforcement_literal(0));

View File

@@ -265,29 +265,9 @@ void AddNonOverlappingRectangles(const std::vector<IntervalVariable>& x,
if (num_boxes < params.no_overlap_2d_boolean_relations_limit()) {
auto* implications = model->GetOrCreate<BinaryImplicationGraph>();
auto* sat_solver = model->GetOrCreate<SatSolver>();
auto* encoder = model->GetOrCreate<IntegerEncoder>();
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
DCHECK_EQ(sat_solver->CurrentDecisionLevel(), 0);
// Creates and returns the Boolean equivalent to a <= b.
const auto f = [repository, integer_trail, encoder](
const AffineExpression& a, const AffineExpression& b) {
if (a.var == b.var && a.coeff == b.coeff) {
return (a.constant <= b.constant) ? encoder->GetTrueLiteral()
: encoder->GetFalseLiteral();
}
if (integer_trail->UpperBound(a) <= integer_trail->LowerBound(b)) {
return encoder->GetTrueLiteral();
}
if (integer_trail->LowerBound(a) > integer_trail->UpperBound(b)) {
return encoder->GetFalseLiteral();
}
repository->CreatePrecedenceLiteral(a, b);
const LiteralIndex index = repository->GetPrecedenceLiteral(a, b);
CHECK(index != kNoLiteralIndex);
return Literal(index);
};
for (int i = 0; i < num_boxes; ++i) {
if (repository->IsAbsent(x[i])) continue;
if (repository->IsAbsent(y[i])) continue;
@@ -296,8 +276,10 @@ void AddNonOverlappingRectangles(const std::vector<IntervalVariable>& x,
if (repository->IsAbsent(y[j])) continue;
// At most one of these two x options is true.
const Literal x_ij = f(repository->End(x[i]), repository->Start(x[j]));
const Literal x_ji = f(repository->End(x[j]), repository->Start(x[i]));
const Literal x_ij = repository->GetOrCreatePrecedenceLiteral(
repository->End(x[i]), repository->Start(x[j]));
const Literal x_ji = repository->GetOrCreatePrecedenceLiteral(
repository->End(x[j]), repository->Start(x[i]));
if ((integer_trail->LowerBound(repository->Size(x[i])) > 0 ||
integer_trail->LowerBound(repository->Size(x[j])) > 0) &&
!implications->AddAtMostOne({x_ij, x_ji})) {
@@ -306,8 +288,10 @@ void AddNonOverlappingRectangles(const std::vector<IntervalVariable>& x,
}
// At most one of these two y options is true.
const Literal y_ij = f(repository->End(y[i]), repository->Start(y[j]));
const Literal y_ji = f(repository->End(y[j]), repository->Start(y[i]));
const Literal y_ij = repository->GetOrCreatePrecedenceLiteral(
repository->End(y[i]), repository->Start(y[j]));
const Literal y_ji = repository->GetOrCreatePrecedenceLiteral(
repository->End(y[j]), repository->Start(y[i]));
if ((integer_trail->LowerBound(repository->Size(y[i])) > 0 ||
integer_trail->LowerBound(repository->Size(y[j])) > 0) &&
!implications->AddAtMostOne({y_ij, y_ji})) {

View File

@@ -16,12 +16,8 @@
#include <string>
#if !defined(__PORTABLE_PLATFORM__)
#include "ortools/base/file.h"
#else
class File {};
#endif // !__PORTABLE_PLATFORM__
#include "absl/types/span.h"
#include "ortools/base/file.h"
#include "ortools/sat/sat_base.h"
namespace operations_research {

View File

@@ -26,7 +26,15 @@ void LinearExpression2::SimpleCanonicalization() {
if (coeffs[1] == 0) vars[1] = kNoIntegerVariable;
// Corner case when the underlying variable is the same.
if (vars[0] == vars[1]) {
if (PositiveVariable(vars[0]) == PositiveVariable(vars[1])) {
// Make sure variable are positive before merging.
for (int i = 0; i < 2; ++i) {
if (!VariableIsPositive(vars[i])) {
coeffs[i] = -coeffs[i];
vars[i] = NegationOf(vars[i]);
}
}
coeffs[0] += coeffs[1];
coeffs[1] = 0;
vars[1] = kNoIntegerVariable;
@@ -49,27 +57,30 @@ void LinearExpression2::SimpleCanonicalization() {
}
void LinearExpression2::CanonicalizeAndUpdateBounds(IntegerValue& lb,
IntegerValue& ub) {
// We need to be able to negate without overflow.
CHECK_GE(lb, kMinIntegerValue);
CHECK_LE(ub, kMaxIntegerValue);
IntegerValue& ub,
bool allow_negation) {
SimpleCanonicalization();
if (coeffs[0] == 0 || coeffs[1] == 0) return; // abort.
bool negate = false;
if (coeffs[0] == 0) {
if (coeffs[1] != 0) {
negate = !VariableIsPositive(vars[1]);
if (allow_negation) {
bool negate = false;
if (coeffs[0] == 0) {
if (coeffs[1] != 0) {
negate = !VariableIsPositive(vars[1]);
}
} else {
negate = !VariableIsPositive(vars[0]);
}
if (negate) {
Negate();
// We need to be able to negate without overflow.
CHECK_GE(lb, kMinIntegerValue);
CHECK_LE(ub, kMaxIntegerValue);
std::swap(lb, ub);
lb = -lb;
ub = -ub;
}
} else {
negate = !VariableIsPositive(vars[0]);
}
if (negate) {
Negate();
std::swap(lb, ub);
lb = -lb;
ub = -ub;
}
// Do gcd division.
@@ -108,7 +119,7 @@ bool BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb,
RelationStatus BestBinaryRelationBounds::GetStatus(LinearExpression2 expr,
IntegerValue lb,
IntegerValue ub) {
IntegerValue ub) const {
expr.CanonicalizeAndUpdateBounds(lb, ub);
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) {
return RelationStatus::IS_UNKNOWN;

View File

@@ -358,7 +358,8 @@ struct LinearExpression2 {
void SimpleCanonicalization();
// This fully canonicalize this, and update the given bounds accordingly.
void CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub);
void CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub,
bool allow_negation = false);
bool operator==(const LinearExpression2& o) const {
return vars[0] == o.vars[0] && vars[1] == o.vars[1] &&
@@ -369,6 +370,13 @@ struct LinearExpression2 {
IntegerVariable vars[2];
};
inline std::ostream& operator<<(std::ostream& os,
const LinearExpression2& expr) {
os << absl::StrCat(expr.coeffs[0], " X", expr.vars[0], " + ", expr.coeffs[1],
" X", expr.vars[1]);
return os;
}
template <typename H>
H AbslHashValue(H h, const LinearExpression2& e) {
h = H::combine(std::move(h), e.vars[0]);
@@ -390,7 +398,7 @@ class BestBinaryRelationBounds {
// Returns the known status of expr <= bound.
RelationStatus GetStatus(LinearExpression2 expr, IntegerValue lb,
IntegerValue ub);
IntegerValue ub) const;
private:
// The best bound on the given "canonicalized" expression.

View File

@@ -759,7 +759,7 @@ std::function<BooleanOrIntegerLiteral()> DisjunctivePrecedenceSearchHeuristic(
const auto a = best_helper->GetIntervalDefinition(best_before);
const auto b = best_helper->GetIntervalDefinition(best_after);
return BooleanOrIntegerLiteral(
repo->GetOrCreateDisjunctivePrecedenceLiteral(a, b));
repo->GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial(a, b));
}
return BooleanOrIntegerLiteral();
@@ -867,7 +867,7 @@ std::function<BooleanOrIntegerLiteral()> CumulativePrecedenceSearchHeuristic(
open_tasks.push_back(first_skipped_task);
// TODO(user): If the two box cannot overlap because of high demand, use
// repo.CreateDisjunctivePrecedenceLiteral() instead.
// repo.CreateDisjunctivePrecedenceLiteralIfNonTrivial() instead.
//
// TODO(user): Add heuristic ordering for creating interesting precedence
// first.
@@ -908,8 +908,8 @@ std::function<BooleanOrIntegerLiteral()> CumulativePrecedenceSearchHeuristic(
}
// It shouldn't be able to fail since s can be before t.
CHECK(repo->CreatePrecedenceLiteral(helper->Ends()[s],
helper->Starts()[t]));
CHECK(repo->CreatePrecedenceLiteralIfNonTrivial(
helper->Ends()[s], helper->Starts()[t]));
}
// Branch on that precedence.
@@ -962,7 +962,7 @@ std::function<BooleanOrIntegerLiteral()> CumulativePrecedenceSearchHeuristic(
<< " " << best_helper->TaskDebugString(best_after);
const AffineExpression end_a = best_helper->Ends()[best_before];
const AffineExpression start_b = best_helper->Starts()[best_after];
repo->CreatePrecedenceLiteral(end_a, start_b);
repo->CreatePrecedenceLiteralIfNonTrivial(end_a, start_b);
return BooleanOrIntegerLiteral(
repo->GetPrecedenceLiteral(end_a, start_b));
}
@@ -1421,7 +1421,7 @@ LiteralIndex IntegerSearchHelper::GetDecisionLiteral(
bool IntegerSearchHelper::GetDecision(
const std::function<BooleanOrIntegerLiteral()>& f, LiteralIndex* decision) {
*decision = kNoLiteralIndex;
while (!time_limit_->LimitReached()) {
do {
BooleanOrIntegerLiteral new_decision;
if (integer_trail_->InPropagationLoop()) {
const IntegerVariable var =
@@ -1451,7 +1451,7 @@ bool IntegerSearchHelper::GetDecision(
*decision = GetDecisionLiteral(new_decision);
if (*decision != kNoLiteralIndex) break;
}
} while (!time_limit_->LimitReached());
return true;
}

View File

@@ -18,15 +18,17 @@
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/meta/type_traits.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "ortools/base/strong_vector.h"
#include "ortools/sat/clause.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_base.h"
#include "ortools/sat/integer_expr.h"
#include "ortools/sat/linear_constraint.h"
#include "ortools/sat/model.h"
#include "ortools/sat/no_overlap_2d_helper.h"
#include "ortools/sat/precedences.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/sat_solver.h"
#include "ortools/sat/scheduling_helpers.h"
@@ -35,6 +37,14 @@
namespace operations_research {
namespace sat {
IntervalsRepository::IntervalsRepository(Model* model)
: model_(model),
assignment_(model->GetOrCreate<Trail>()->Assignment()),
sat_solver_(model->GetOrCreate<SatSolver>()),
implications_(model->GetOrCreate<BinaryImplicationGraph>()),
integer_trail_(model->GetOrCreate<IntegerTrail>()),
relations_maps_(model->GetOrCreate<BinaryRelationsMaps>()) {}
IntervalVariable IntervalsRepository::CreateInterval(IntegerVariable start,
IntegerVariable end,
IntegerVariable size,
@@ -78,7 +88,7 @@ IntervalVariable IntervalsRepository::CreateInterval(AffineExpression start,
void IntervalsRepository::CreateDisjunctivePrecedenceLiteral(
IntervalVariable a, IntervalVariable b) {
GetOrCreateDisjunctivePrecedenceLiteral(
GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial(
IntervalDefinition{.start = Start(a),
.end = End(a),
.size = Size(a),
@@ -93,7 +103,8 @@ void IntervalsRepository::CreateDisjunctivePrecedenceLiteral(
: std::nullopt});
}
LiteralIndex IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteral(
LiteralIndex
IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial(
const IntervalDefinition& a, const IntervalDefinition& b) {
auto it = disjunctive_precedences_.find({a, b});
if (it != disjunctive_precedences_.end()) return it->second.Index();
@@ -143,7 +154,26 @@ LiteralIndex IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteral(
return kNoLiteralIndex;
}
// Abort if the relation is already known.
if (relations_maps_->GetPrecedenceStatus(a.end, b.start) ==
RelationStatus::IS_TRUE ||
relations_maps_->GetPrecedenceStatus(b.end, a.start) ==
RelationStatus::IS_TRUE) {
return kNoLiteralIndex;
}
// Create a new literal.
//
// TODO(user): If there are no enforcement and we already have at one of:
// - s <=> a.end <= b.start
// - t <=> b.end <= a.start
// We could use (s, not(s)) or (not(t), t) and make sure s = not(t) if both
// exists.
//
// TODO(user): Otherwise, an alternative solution is to create s and t (can be
// one more Boolean though), and have enforcement => s + t == 1. The later
// might not even be needed though, since interval equation should already
// enforce it.
const BooleanVariable boolean_var = sat_solver_->NewBooleanVariable();
const Literal a_before_b = Literal(boolean_var, true);
disjunctive_precedences_.insert({{a, b}, a_before_b});
@@ -151,9 +181,10 @@ LiteralIndex IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteral(
// Also insert it in precedences.
if (enforcement_literals.empty()) {
// TODO(user): also add the reverse like start_b + 1 <= end_a if negated?
precedences_.insert({{a.end, b.start}, a_before_b});
precedences_.insert({{b.end, a.start}, a_before_b.Negated()});
relations_maps_->AddReifiedPrecedenceIfNonTrivial(a_before_b, a.end,
b.start);
relations_maps_->AddReifiedPrecedenceIfNonTrivial(a_before_b.Negated(),
b.end, a.start);
}
enforcement_literals.push_back(a_before_b);
@@ -179,25 +210,22 @@ LiteralIndex IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteral(
return a_before_b;
}
bool IntervalsRepository::CreatePrecedenceLiteral(AffineExpression x,
AffineExpression y) {
if (precedences_.contains({x, y})) return false;
bool IntervalsRepository::CreatePrecedenceLiteralIfNonTrivial(
AffineExpression x, AffineExpression y) {
const LiteralIndex index = relations_maps_->GetReifiedPrecedence(x, y);
if (index != kNoLiteralIndex) return false;
// We want l => x <= y and not(l) => x > y <=> y + 1 <= x
// Do not create l if the relation is always true or false.
if (integer_trail_->UpperBound(x) <= integer_trail_->LowerBound(y)) {
return false;
}
if (integer_trail_->LowerBound(x) > integer_trail_->UpperBound(y)) {
if (relations_maps_->GetPrecedenceStatus(x, y) !=
RelationStatus::IS_UNKNOWN) {
return false;
}
// Create a new literal.
const BooleanVariable boolean_var = sat_solver_->NewBooleanVariable();
const Literal x_before_y = Literal(boolean_var, true);
// TODO(user): Also add {{y_plus_one, x}, x_before_y.Negated()} ?
precedences_.insert({{x, y}, x_before_y});
relations_maps_->AddReifiedPrecedenceIfNonTrivial(x_before_y, x, y);
AffineExpression y_plus_one = y;
y_plus_one.constant += 1;
@@ -208,9 +236,20 @@ bool IntervalsRepository::CreatePrecedenceLiteral(AffineExpression x,
LiteralIndex IntervalsRepository::GetPrecedenceLiteral(
AffineExpression x, AffineExpression y) const {
const auto it = precedences_.find({x, y});
if (it != precedences_.end()) return it->second.Index();
return kNoLiteralIndex;
return relations_maps_->GetReifiedPrecedence(x, y);
}
Literal IntervalsRepository::GetOrCreatePrecedenceLiteral(AffineExpression x,
AffineExpression y) {
{
const LiteralIndex index = GetPrecedenceLiteral(x, y);
if (index != kNoLiteralIndex) return Literal(index);
}
CHECK(CreatePrecedenceLiteralIfNonTrivial(x, y));
const LiteralIndex index = relations_maps_->GetReifiedPrecedence(x, y);
CHECK_NE(index, kNoLiteralIndex);
return Literal(index);
}
// TODO(user): Ideally we should sort the vector of variables, but right now

View File

@@ -42,12 +42,7 @@ namespace sat {
// provides many helper functions to add precedences relation between intervals.
class IntervalsRepository {
public:
explicit IntervalsRepository(Model* model)
: model_(model),
assignment_(model->GetOrCreate<Trail>()->Assignment()),
sat_solver_(model->GetOrCreate<SatSolver>()),
implications_(model->GetOrCreate<BinaryImplicationGraph>()),
integer_trail_(model->GetOrCreate<IntegerTrail>()) {}
explicit IntervalsRepository(Model* model);
// This type is neither copyable nor movable.
IntervalsRepository(const IntervalsRepository&) = delete;
@@ -149,19 +144,25 @@ class IntervalsRepository {
// If such literal already exists this returns it.
void CreateDisjunctivePrecedenceLiteral(IntervalVariable a,
IntervalVariable b);
LiteralIndex GetOrCreateDisjunctivePrecedenceLiteral(
LiteralIndex GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial(
const IntervalDefinition& a, const IntervalDefinition& b);
// Creates a literal l <=> y >= x.
// Returns true if such literal is "non-trivial" and was created.
bool CreatePrecedenceLiteral(AffineExpression x, AffineExpression y);
bool CreatePrecedenceLiteralIfNonTrivial(AffineExpression x,
AffineExpression y);
// Returns a literal l <=> y >= x if it exist or kNoLiteralIndex
// otherwise. This could be the one created by
// CreateDisjunctivePrecedenceLiteral() or CreatePrecedenceLiteral().
// CreateDisjunctivePrecedenceLiteral() or
// CreatePrecedenceLiteralIfNonTrivial().
LiteralIndex GetPrecedenceLiteral(AffineExpression x,
AffineExpression y) const;
// Combines the two calls. Note that we will only create literals when the
// relation is not known.
Literal GetOrCreatePrecedenceLiteral(AffineExpression x, AffineExpression y);
const std::vector<SchedulingConstraintHelper*>& AllDisjunctiveHelpers()
const {
return disjunctive_helpers_;
@@ -188,6 +189,7 @@ class IntervalsRepository {
SatSolver* sat_solver_;
BinaryImplicationGraph* implications_;
IntegerTrail* integer_trail_;
BinaryRelationsMaps* relations_maps_;
// Literal indicating if the tasks is executed. Tasks that are always executed
// will have a kNoLiteralIndex entry in this vector.
@@ -212,16 +214,10 @@ class IntervalsRepository {
SchedulingDemandHelper*>
demand_helper_repository_;
// Disjunctive and normal precedences.
//
// Note that for normal precedences, we use directly the affine expression so
// that if many intervals share the same start, we don't re-create Booleans
// for no reason.
// Disjunctive precedences.
absl::flat_hash_map<std::pair<IntervalDefinition, IntervalDefinition>,
Literal>
disjunctive_precedences_;
absl::flat_hash_map<std::pair<AffineExpression, AffineExpression>, Literal>
precedences_;
// Disjunctive/Cumulative helpers_.
std::vector<SchedulingConstraintHelper*> disjunctive_helpers_;

View File

@@ -171,6 +171,17 @@ LinearExpression LinearConstraintBuilder::BuildExpression() {
return result;
}
double LinearConstraint::NormalizedViolation(
const util_intops::StrongVector<IntegerVariable, double>& lp_values) const {
const double activity = ComputeActivity(*this, lp_values);
const double violation =
std::max(activity - ToDouble(ub), ToDouble(lb) - activity);
if (violation <= 0.0) return 0.0;
const double l2_norm = ComputeL2Norm(*this);
return violation / l2_norm;
}
double ComputeActivity(
const LinearConstraint& constraint,
const util_intops::StrongVector<IntegerVariable, double>& values) {

View File

@@ -66,6 +66,12 @@ struct LinearConstraint {
LinearConstraint() = default;
LinearConstraint(IntegerValue _lb, IntegerValue _ub) : lb(_lb), ub(_ub) {}
// Compute the normalized violation of the constraint.
// For a cut, this is the usual definition of its efficacy.
double NormalizedViolation(
const util_intops::StrongVector<IntegerVariable, double>& lp_values)
const;
// Resize the LinearConstraint to have space for num_terms. We always
// re-allocate if the size is different to always be tight in memory.
void resize(int size) {
@@ -234,7 +240,7 @@ class LinearConstraintBuilder {
ABSL_MUST_USE_RESULT bool AddDecomposedProduct(
absl::Span<const LiteralValueValue> product);
// Add literal * coeff to the constaint. Returns false and do nothing if the
// Add literal * coeff to the constraint. Returns false and do nothing if the
// given literal didn't have an integer view.
ABSL_MUST_USE_RESULT bool AddLiteralTerm(
Literal lit, IntegerValue coeff = IntegerValue(1));
@@ -312,8 +318,8 @@ double ComputeActivity(
// linear relaxation. This is a bit relaxed compared to what we require for
// generic linear constraint that are used in our CP propagators.
//
// If this check pass, our constraint should be safe to use in our simplication
// code, our cut computation, etc...
// If this check pass, our constraint should be safe to use in our
// simplification code, our cut computation, etc...
bool PossibleOverflow(const IntegerTrail& integer_trail,
const LinearConstraint& constraint);
@@ -329,7 +335,7 @@ double ScalarProduct(const LinearConstraint& constraint1,
const LinearConstraint& constraint2);
// Computes the GCD of the constraint coefficient, and divide them by it. This
// also tighten the constraint bounds assumming all the variables are integer.
// also tighten the constraint bounds assuming all the variables are integer.
void DivideByGCD(LinearConstraint* constraint);
// Removes the entries with a coefficient of zero.

View File

@@ -29,7 +29,6 @@
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/log/vlog_is_on.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
@@ -979,11 +978,8 @@ void TopNCuts::AddCut(
LinearConstraint ct, absl::string_view name,
const util_intops::StrongVector<IntegerVariable, double>& lp_solution) {
if (ct.num_terms == 0) return;
const double activity = ComputeActivity(ct, lp_solution);
const double violation =
std::max(activity - ToDouble(ct.ub), ToDouble(ct.lb) - activity);
const double l2_norm = ComputeL2Norm(ct);
cuts_.Add({std::string(name), std::move(ct)}, violation / l2_norm);
const double normalized_violation = ct.NormalizedViolation(lp_solution);
cuts_.Add({std::string(name), std::move(ct)}, normalized_violation);
}
void TopNCuts::TransferToManager(LinearConstraintManager* manager) {

View File

@@ -13,7 +13,9 @@
#include "ortools/sat/linear_constraint.h"
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <utility>
#include <vector>
@@ -44,8 +46,11 @@ TEST(ComputeActivityTest, BasicBehavior) {
util_intops::StrongVector<IntegerVariable, double> values = {0.5, 0.0, 1.4,
0.0, -2.1, 0.0};
EXPECT_NEAR(ComputeActivity(ct.Build(), values), 1 * 0.5 - 2 * 1.4 - 3 * 2.1,
1e-6);
const double expected_activity = 1 * 0.5 - 2 * 1.4 - 3 * 2.1;
EXPECT_NEAR(ComputeActivity(ct.Build(), values), expected_activity, 1e-6);
const double expected_violation =
std::abs(expected_activity) / std::sqrt(1 + 4 + 9);
EXPECT_NEAR(ct.Build().NormalizedViolation(values), expected_violation, 1e-6);
}
TEST(ComputeActivityTest, EmptyConstraint) {

View File

@@ -1487,5 +1487,153 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints(
return num_added_constraints;
}
BinaryRelationsMaps::BinaryRelationsMaps(Model* model)
: integer_trail_(model->GetOrCreate<IntegerTrail>()),
integer_encoder_(model->GetOrCreate<IntegerEncoder>()),
shared_stats_(model->GetOrCreate<SharedStatistics>()) {
int index = 0;
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
[index = index, trail = model->GetOrCreate<Trail>(), this]() mutable {
DCHECK_EQ(trail->CurrentDecisionLevel(), 0);
absl::flat_hash_set<Literal> relevant_true_literals;
for (; index < trail->Index(); ++index) {
const Literal l = (*trail)[index];
if (variable_appearing_in_reified_relations_.contains(l.Variable())) {
relevant_true_literals.insert(l);
}
}
if (relevant_true_literals.empty()) return true;
// Linear scan.
for (const auto [l, expr, ub] : all_reified_relations_) {
if (relevant_true_literals.contains(l)) {
AddRelationBounds(expr, kMinIntegerValue, ub);
VLOG(2) << "New fixed precedence: " << expr << " <= " << ub
<< " (was reified by " << l << ")";
} else if (relevant_true_literals.contains(l.Negated())) {
AddRelationBounds(expr, ub + 1, kMaxIntegerValue);
VLOG(2) << "New fixed precedence: " << expr << " > " << ub
<< " (was reified by not(" << l << "))";
}
}
return true;
});
}
BinaryRelationsMaps::~BinaryRelationsMaps() {
if (!VLOG_IS_ON(1)) return;
std::vector<std::pair<std::string, int64_t>> stats;
stats.push_back({"BinaryRelationsMaps/num_relations", num_updates_});
shared_stats_->AddStats(stats);
}
std::pair<IntegerValue, IntegerValue>
BinaryRelationsMaps::GetImpliedLevelZeroBounds(
const LinearExpression2& expr) const {
// Compute the implied bounds on the expression.
IntegerValue implied_lb = 0;
IntegerValue implied_ub = 0;
if (expr.coeffs[0] != 0) {
CHECK_GE(expr.vars[0], 0);
implied_lb +=
expr.coeffs[0] * integer_trail_->LevelZeroLowerBound(expr.vars[0]);
implied_ub +=
expr.coeffs[0] * integer_trail_->LevelZeroUpperBound(expr.vars[0]);
}
if (expr.coeffs[1] != 0) {
CHECK_GE(expr.vars[1], 0);
implied_lb +=
expr.coeffs[1] * integer_trail_->LevelZeroLowerBound(expr.vars[1]);
implied_ub +=
expr.coeffs[1] * integer_trail_->LevelZeroUpperBound(expr.vars[1]);
}
return {implied_lb, implied_ub};
}
void BinaryRelationsMaps::AddRelationBounds(LinearExpression2 expr,
IntegerValue lb, IntegerValue ub) {
expr.CanonicalizeAndUpdateBounds(lb, ub);
const auto [implied_lb, implied_ub] = GetImpliedLevelZeroBounds(expr);
lb = std::max(lb, implied_lb);
ub = std::min(ub, implied_ub);
if (lb > ub) return; // unsat ??
if (lb == implied_lb && ub == implied_ub) return; // trivially true.
if (best_upper_bounds_.Add(expr, lb, ub)) {
// TODO(user): Also push them to a global shared repository after
// remapping IntegerVariable to proto indices.
++num_updates_;
}
}
RelationStatus BinaryRelationsMaps::GetStatus(LinearExpression2 expr,
IntegerValue lb,
IntegerValue ub) const {
expr.CanonicalizeAndUpdateBounds(lb, ub);
const auto [implied_lb, implied_ub] = GetImpliedLevelZeroBounds(expr);
lb = std::max(lb, implied_lb);
ub = std::min(ub, implied_ub);
// Returns directly if the status can be derived from the implied bounds.
if (lb > ub) return RelationStatus::IS_FALSE;
if (lb == implied_lb && ub == implied_ub) return RelationStatus::IS_TRUE;
// Relax as best_upper_bounds_.GetStatus() might have older bounds.
if (lb == implied_lb) lb = kMinIntegerValue;
if (ub == implied_ub) ub = kMaxIntegerValue;
return best_upper_bounds_.GetStatus(expr, lb, ub);
}
std::pair<LinearExpression2, IntegerValue> BinaryRelationsMaps::FromDifference(
const AffineExpression& a, const AffineExpression& b) const {
LinearExpression2 expr;
expr.vars[0] = a.var;
expr.vars[1] = b.var;
expr.coeffs[0] = a.coeff;
expr.coeffs[1] = -b.coeff;
IntegerValue lb = kMinIntegerValue; // unused.
IntegerValue ub = b.constant - a.constant;
expr.CanonicalizeAndUpdateBounds(lb, ub, /*allow_negation=*/false);
return {std::move(expr), ub};
}
RelationStatus BinaryRelationsMaps::GetPrecedenceStatus(
AffineExpression a, AffineExpression b) const {
const auto [expr, ub] = FromDifference(a, b);
return GetStatus(expr, kMinIntegerValue, ub);
}
void BinaryRelationsMaps::AddReifiedPrecedenceIfNonTrivial(Literal l,
AffineExpression a,
AffineExpression b) {
const auto [expr, ub] = FromDifference(a, b);
const RelationStatus status = GetStatus(expr, kMinIntegerValue, ub);
if (status != RelationStatus::IS_UNKNOWN) return;
relation_to_lit_.insert({{expr, ub}, l});
variable_appearing_in_reified_relations_.insert(l.Variable());
all_reified_relations_.push_back({l, expr, ub});
}
LiteralIndex BinaryRelationsMaps::GetReifiedPrecedence(AffineExpression a,
AffineExpression b) {
const auto [expr, ub] = FromDifference(a, b);
const RelationStatus status = GetStatus(expr, kMinIntegerValue, ub);
if (status == RelationStatus::IS_TRUE) {
return integer_encoder_->GetTrueLiteral().Index();
}
if (status == RelationStatus::IS_FALSE) {
return integer_encoder_->GetFalseLiteral().Index();
}
const auto it = relation_to_lit_.find({expr, ub});
if (it == relation_to_lit_.end()) return kNoLiteralIndex;
return it->second;
}
} // namespace sat
} // namespace operations_research

View File

@@ -506,6 +506,7 @@ struct Relation {
class BinaryRelationRepository {
public:
int size() const { return relations_.size(); }
// The returned relation is guaranteed to only have positive variables.
const Relation& relation(int index) const { return relations_[index]; }
@@ -574,6 +575,64 @@ class BinaryRelationRepository {
var_pair_to_relations_;
};
// TODO(user): Merge with BinaryRelationRepository. Note that this one provides
// different indexing though, so it could be kept separate. The
// LinearExpression2 data structure is also slightly more efficient.
class BinaryRelationsMaps {
public:
explicit BinaryRelationsMaps(Model* model);
~BinaryRelationsMaps();
// This mainly wraps BestBinaryRelationBounds, but in addition it checks the
// current LevelZero variable bounds to detect trivially true or false
// relation.
void AddRelationBounds(LinearExpression2 expr, IntegerValue lb,
IntegerValue ub);
RelationStatus GetStatus(LinearExpression2 expr, IntegerValue lb,
IntegerValue ub) const;
// Return the status of a <= b;
RelationStatus GetPrecedenceStatus(AffineExpression a,
AffineExpression b) const;
// Register the fact that l <=> ( a <= b ).
// These are considered equivalence relation.
void AddReifiedPrecedenceIfNonTrivial(Literal l, AffineExpression a,
AffineExpression b);
// Returns kNoLiteralIndex if we don't have a literal <=> ( a <= b ), or
// returns that literal if we have one. Note that we will return the
// true/false literal if the status is known at level zero.
LiteralIndex GetReifiedPrecedence(AffineExpression a, AffineExpression b);
private:
// Return the pair (a - b) <= rhs.
std::pair<LinearExpression2, IntegerValue> FromDifference(
const AffineExpression& a, const AffineExpression& b) const;
std::pair<IntegerValue, IntegerValue> GetImpliedLevelZeroBounds(
const LinearExpression2& expr) const;
IntegerTrail* integer_trail_;
IntegerEncoder* integer_encoder_;
SharedStatistics* shared_stats_;
BestBinaryRelationBounds best_upper_bounds_;
int64_t num_updates_ = 0;
// This stores relations l <=> (linear2 <= rhs).
absl::flat_hash_map<std::pair<LinearExpression2, IntegerValue>, Literal>
relation_to_lit_;
// This is used to detect relations that become fixed at level zero and
// "upgrade" them to non-enforced relations. Because we only do that when
// we fix variable, a linear scan shouldn't be too bad and is relatively
// compact memory wise.
absl::flat_hash_set<BooleanVariable> variable_appearing_in_reified_relations_;
std::vector<std::tuple<Literal, LinearExpression2, IntegerValue>>
all_reified_relations_;
};
// Detects if at least one of a subset of linear of size 2 or 1, touching the
// same variable, must be true. When this is the case we add a new propagator to
// propagate that fact.

View File

@@ -122,6 +122,12 @@ void GetRelationshipForConstraint(const ConstraintProto& ct,
}
return;
}
case ConstraintProto::kExactlyOne: {
for (const int lit : ct.exactly_one().literals()) {
deducible_vars->insert(PositiveRef(lit));
}
return;
}
default:
break;
}
@@ -613,6 +619,20 @@ bool ComputeAllVariablesFromPrimaryVariables(
product -= target.offset();
(*solution)[var] = product / coeff_of_var;
} break;
case ConstraintProto::kExactlyOne: {
(*solution)[var] = 0;
int sum = 0;
for (const int lit : ct.exactly_one().literals()) {
const int positive_ref = PositiveRef(lit);
DCHECK(positive_ref == var ||
!dependent_variables_set.IsSet(positive_ref));
sum += RefIsPositive(lit) ? (*solution)[positive_ref]
: 1 - (*solution)[positive_ref];
}
if (sum != 1) {
(*solution)[var] ^= 1;
}
} break;
default:
break;
}

View File

@@ -26,6 +26,8 @@ namespace {
using ::google::protobuf::contrib::parse_proto::ParseTestProto;
using ::testing::Contains;
using ::testing::ElementsAre;
using ::testing::EqualsProto;
using ::testing::Pair;
TEST(PrimaryVariablesTest, BasicExample) {
@@ -114,6 +116,25 @@ TEST(PrimaryVariablesTest, WithIntProd) {
EXPECT_EQ(all_variables, solution);
}
TEST(PrimaryVariablesTest, WithExactlyOne) {
const CpModelProto model = ParseTestProto(R"pb(
variables { domain: [ 0, 1 ] }
variables { domain: [ 0, 1 ] }
variables { domain: [ 0, 1 ] }
variables { domain: [ 0, 1 ] }
variables { domain: [ 0, 1 ] }
constraints { exactly_one { literals: [ 0, 1, 2, 3 ] } }
)pb");
const VariableRelationships relationships =
ComputeVariableRelationships(model);
EXPECT_EQ(relationships.secondary_variables.size(), 1);
const ConstraintProto expected = ParseTestProto(R"pb(
exactly_one { literals: [ 0, 1, 2, 3 ] }
)pb");
EXPECT_THAT(relationships.dependency_resolution_constraint,
ElementsAre(EqualsProto(expected)));
}
} // namespace
} // namespace sat
} // namespace operations_research

View File

@@ -2485,10 +2485,11 @@ void SatSolver::MinimizeConflictRecursively(std::vector<Literal>* conflict) {
// be infered by some other variables in the conflict.
// Note that we can skip the first position since this is the 1-UIP.
int index = 1;
TimeLimitCheckEveryNCalls time_limit_check(100, time_limit_);
for (int i = 1; i < conflict->size(); ++i) {
const BooleanVariable var = (*conflict)[i].Variable();
const AssignmentInfo& info = trail_->Info(var);
if (time_limit_->LimitReached() ||
if (time_limit_check.LimitReached() ||
info.type == AssignmentType::kSearchDecision ||
info.trail_index <= min_trail_index_per_level_[info.level] ||
!CanBeInferedFromConflictVariables(var)) {

File diff suppressed because it is too large Load Diff

View File

@@ -18,11 +18,11 @@
#include <string>
#include <vector>
#include "absl/types/span.h"
#include "ortools/sat/cuts.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_base.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/scheduling_helpers.h"
namespace operations_research {
@@ -100,21 +100,34 @@ CutGenerator CreateNoOverlapCompletionTimeCutGenerator(
// Internal methods and data structures, useful for testing.
// Base event type for scheduling cuts.
struct BaseEvent {
BaseEvent(int t, SchedulingConstraintHelper* x_helper);
// Stores the event for a task (interval, demand).
// For a no_overlap constraint, demand is always between 0 and 1.
// For a cumulative constraint, demand must be between 0 and capacity_max.
struct CompletionTimeEvent {
CompletionTimeEvent(int t, SchedulingConstraintHelper* x_helper,
SchedulingDemandHelper* demands_helper);
// Cache of the intervals bound on the x direction.
IntegerValue x_start_min;
IntegerValue x_start_max;
IntegerValue x_end_min;
IntegerValue x_end_max;
IntegerValue x_size_min;
IntegerValue x_size_max;
// The index of the task in the helper.
int task_index;
// Cache of the bounds on the y direction.
IntegerValue y_size_min;
IntegerValue y_size_max;
// Cache of the bounds of the interval.
IntegerValue start_min;
IntegerValue start_max;
IntegerValue end_min;
IntegerValue end_max;
IntegerValue size_min;
// The lp value of the end of the interval.
AffineExpression end;
double lp_end = 0.0;
// Cache of the bounds of the demand.
IntegerValue demand_min;
// If we know that the size on y is fixed, we can use some heuristic to
// compute the maximum subset sums under the capacity and use that instead
// of the full capacity.
bool demand_is_fixed = false;
// The energy min of this event.
IntegerValue energy_min;
@@ -127,24 +140,6 @@ struct BaseEvent {
// model.
bool use_energy = false;
// If we know that the size on y is fixed, we can use some heuristic to
// compute the maximum subset sums under the capacity and use that instead
// of the full capacity.
bool y_size_is_fixed() const { return y_size_min == y_size_max; }
void PropagateDecomposedEnergy(const VariablesAssignment& assignment);
};
// Stores the event for a rectangle along the two axis x and y.
// For a no_overlap constraint, y is always of size 1 between 0 and 1.
// For a cumulative constraint, y is the demand that must be between 0 and
// capacity_max.
struct CtEvent : BaseEvent {
CtEvent(int t, SchedulingConstraintHelper* x_helper);
// The lp value of the end of the x interval.
AffineExpression x_end;
double x_lp_end;
// Indicates if the cut is lifted, that is if it includes tasks that are not
// strictly contained in the current time window.
bool lifted = false;
@@ -160,30 +155,11 @@ struct CtEvent : BaseEvent {
// small, like <= 10. They should also starts in index order.
//
// Optim: If both sums are proven <= to the corresponding threshold, we abort.
struct PermutableEvent {
PermutableEvent(int i, CtEvent e)
: index(i),
start_min(e.x_start_min),
start_max(e.x_start_max),
size(e.x_size_min),
demand(e.y_size_min),
weight(e.y_size_min) {}
bool operator<(const PermutableEvent& o) const { return index < o.index; }
int index; // for < to be used by std::next_permutation().
IntegerValue start_min;
IntegerValue start_max;
IntegerValue size;
IntegerValue demand;
IntegerValue weight;
};
bool ComputeMinSumOfWeightedEndMins(std::vector<PermutableEvent>& events,
IntegerValue capacity_max,
IntegerValue& min_sum_of_end_mins,
IntegerValue& min_sum_of_weighted_end_mins,
IntegerValue unweighted_threshold,
IntegerValue weighted_threshold);
bool ComputeMinSumOfWeightedEndMins(
absl::Span<const CompletionTimeEvent> events, IntegerValue capacity_max,
double sum_of_ends_lp, double sum_of_weighted_ends_lp,
IntegerValue& min_sum_of_end_mins,
IntegerValue& min_sum_of_weighted_end_mins);
} // namespace sat
} // namespace operations_research

View File

@@ -15,7 +15,6 @@
#include <stdint.h>
#include <functional>
#include <optional>
#include <string>
#include <vector>
@@ -37,6 +36,7 @@
#include "ortools/sat/linear_constraint_manager.h"
#include "ortools/sat/model.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/scheduling_helpers.h"
#include "ortools/util/strong_integers.h"
namespace operations_research {
@@ -398,21 +398,21 @@ TEST(ComputeMinSumOfEndMinsTest, CombinationOf3) {
SchedulingConstraintHelper* helper =
model.GetOrCreate<IntervalsRepository>()->GetOrCreateHelper({i1, i2, i3});
CtEvent e1(0, helper);
e1.y_size_min = two;
CtEvent e2(1, helper);
e2.y_size_min = one;
CtEvent e3(2, helper);
e3.y_size_min = one;
std::vector<PermutableEvent> events = {{0, e1}, {1, e2}, {1, e3}};
SchedulingDemandHelper* demands_helper =
new SchedulingDemandHelper({two, one, one}, helper, &model);
model.TakeOwnership(demands_helper);
CompletionTimeEvent e1(0, helper, demands_helper);
CompletionTimeEvent e2(1, helper, demands_helper);
CompletionTimeEvent e3(2, helper, demands_helper);
const std::vector<CompletionTimeEvent> events = {e1, e2, e3};
IntegerValue min_sum_of_end_mins(0);
IntegerValue min_sum_of_weighted_end_mins(0);
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(
events, two, min_sum_of_end_mins, min_sum_of_weighted_end_mins,
kMinIntegerValue, kMinIntegerValue));
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(events, two, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
EXPECT_EQ(min_sum_of_end_mins, 17);
EXPECT_EQ(min_sum_of_weighted_end_mins, 21);
EXPECT_EQ(min_sum_of_weighted_end_mins, 86);
}
TEST(ComputeMinSumOfEndMinsTest, CombinationOf3ConstraintStart) {
@@ -442,21 +442,22 @@ TEST(ComputeMinSumOfEndMinsTest, CombinationOf3ConstraintStart) {
SchedulingConstraintHelper* helper =
model.GetOrCreate<IntervalsRepository>()->GetOrCreateHelper({i1, i2, i3});
CtEvent e1(0, helper);
e1.y_size_min = two;
CtEvent e2(1, helper);
e2.y_size_min = one;
CtEvent e3(2, helper);
e3.y_size_min = one;
std::vector<PermutableEvent> events = {{0, e1}, {1, e2}, {2, e3}};
SchedulingDemandHelper* demands_helper =
new SchedulingDemandHelper({two, one, one}, helper, &model);
model.TakeOwnership(demands_helper);
IntegerValue min_sum_of_end_mins(0);
IntegerValue min_sum_of_weighted_end_mins(0);
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(
events, two, min_sum_of_end_mins, min_sum_of_weighted_end_mins,
kMinIntegerValue, kMinIntegerValue));
CompletionTimeEvent e1(0, helper, demands_helper);
CompletionTimeEvent e2(1, helper, demands_helper);
CompletionTimeEvent e3(2, helper, demands_helper);
const std::vector<CompletionTimeEvent> events = {e1, e2, e3};
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
ASSERT_TRUE(ComputeMinSumOfWeightedEndMins(events, two, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
EXPECT_EQ(min_sum_of_end_mins, 18);
EXPECT_EQ(min_sum_of_weighted_end_mins, 21);
EXPECT_EQ(min_sum_of_weighted_end_mins, 86);
}
TEST(ComputeMinSumOfEndMinsTest, Infeasible) {
@@ -486,19 +487,20 @@ TEST(ComputeMinSumOfEndMinsTest, Infeasible) {
SchedulingConstraintHelper* helper =
model.GetOrCreate<IntervalsRepository>()->GetOrCreateHelper({i1, i2, i3});
CtEvent e1(0, helper);
e1.y_size_min = two;
CtEvent e2(1, helper);
e2.y_size_min = one;
CtEvent e3(2, helper);
e3.y_size_min = one;
std::vector<PermutableEvent> events = {{0, e1}, {1, e2}, {2, e3}};
SchedulingDemandHelper* demands_helper =
new SchedulingDemandHelper({two, one, one}, helper, &model);
model.TakeOwnership(demands_helper);
IntegerValue min_sum_of_end_mins(0);
IntegerValue min_sum_of_weighted_end_mins(0);
ASSERT_FALSE(ComputeMinSumOfWeightedEndMins(
events, two, min_sum_of_end_mins, min_sum_of_weighted_end_mins,
kMinIntegerValue, kMinIntegerValue));
CompletionTimeEvent e1(0, helper, demands_helper);
CompletionTimeEvent e2(1, helper, demands_helper);
CompletionTimeEvent e3(2, helper, demands_helper);
const std::vector<CompletionTimeEvent> events = {e1, e2, e3};
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
ASSERT_FALSE(ComputeMinSumOfWeightedEndMins(events, two, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
}
int64_t ExactMakespan(absl::Span<const int> sizes, std::vector<int>& demands,
@@ -539,18 +541,25 @@ int64_t ExactMakespanBruteForce(absl::Span<const int> sizes,
SchedulingConstraintHelper* helper =
model.GetOrCreate<IntervalsRepository>()->GetOrCreateHelper(intervals);
std::vector<PermutableEvent> events;
std::vector<AffineExpression> demands_expr;
for (int i = 0; i < demands.size(); ++i) {
CtEvent e(i, helper);
e.y_size_min = demands[i];
events.emplace_back(i, e);
demands_expr.push_back(AffineExpression(demands[i]));
}
SchedulingDemandHelper* demands_helper =
new SchedulingDemandHelper(demands_expr, helper, &model);
model.TakeOwnership(demands_helper);
std::vector<CompletionTimeEvent> events;
for (int i = 0; i < demands.size(); ++i) {
CompletionTimeEvent e(i, helper, demands_helper);
events.push_back(e);
}
IntegerValue min_sum_of_end_mins(0);
IntegerValue min_sum_of_weighted_end_mins(0);
EXPECT_TRUE(ComputeMinSumOfWeightedEndMins(
events, IntegerValue(capacity), min_sum_of_end_mins,
min_sum_of_weighted_end_mins, kMinIntegerValue, kMinIntegerValue));
IntegerValue min_sum_of_end_mins = 0;
IntegerValue min_sum_of_weighted_end_mins = 0;
EXPECT_TRUE(ComputeMinSumOfWeightedEndMins(events, capacity, 0.01, 0.01,
min_sum_of_end_mins,
min_sum_of_weighted_end_mins));
return min_sum_of_end_mins.value();
}

View File

@@ -21,11 +21,9 @@
#include <vector>
#include "absl/log/check.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "ortools/base/logging.h"
#include "ortools/base/strong_vector.h"
#include "ortools/sat/implied_bounds.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_base.h"

View File

@@ -86,14 +86,15 @@ class SatPostsolver {
int NumClauses() const { return clauses_start_.size(); }
std::vector<Literal> Clause(int i) const {
// TODO(user): we could avoid the copy here, but because clauses_literals_
// is a deque, we do need a special return class and cannot juste use
// is a deque, we do need a special return class and cannot just use
// absl::Span<Literal> for instance.
const int begin = clauses_start_[i];
const int end = i + 1 < clauses_start_.size() ? clauses_start_[i + 1]
: clauses_literals_.size();
const int64_t begin = clauses_start_[i];
const int64_t end = i + 1 < clauses_start_.size()
? clauses_start_[i + 1]
: clauses_literals_.size();
std::vector<Literal> result(clauses_literals_.begin() + begin,
clauses_literals_.begin() + end);
for (int j = 0; j < result.size(); ++j) {
for (int64_t j = 0; j < result.size(); ++j) {
if (result[j] == associated_literal_[i]) {
std::swap(result[0], result[j]);
break;
@@ -118,7 +119,7 @@ class SatPostsolver {
// Stores the arguments of the Add() calls: clauses_start_[i] is the index of
// the first literal of the clause #i in the clauses_literals_ deque.
std::vector<int> clauses_start_;
std::vector<int64_t> clauses_start_;
std::deque<Literal> clauses_literals_;
std::vector<Literal> associated_literal_;