[CP-SAT] improve mod doc; improve precedences in scheduling; speed up circuit data structures

This commit is contained in:
Laurent Perron
2024-07-23 20:01:32 +02:00
parent d0ed31d92e
commit 458e2a1579
11 changed files with 160 additions and 115 deletions

View File

@@ -1223,6 +1223,7 @@ cc_library(
":sat_base",
":sat_solver",
":synchronization",
":util",
"//ortools/base",
"//ortools/base:stl_util",
"//ortools/base:strong_vector",

View File

@@ -56,8 +56,7 @@ CircuitPropagator::CircuitPropagator(const int num_nodes,
values.reserve(num_arcs);
graph_.reserve(num_arcs);
self_arcs_.resize(num_nodes_,
model->GetOrCreate<IntegerEncoder>()->GetFalseLiteral());
self_arcs_.resize(num_nodes_, kFalseLiteralIndex);
for (int arc = 0; arc < num_arcs; ++arc) {
const int head = heads[arc];
const int tail = tails[arc];
@@ -65,7 +64,7 @@ CircuitPropagator::CircuitPropagator(const int num_nodes,
if (assignment_.LiteralIsFalse(literal)) continue;
if (tail == head) {
self_arcs_[tail] = literal;
self_arcs_[tail] = literal.Index();
} else {
graph_[{tail, head}] = literal;
}
@@ -97,7 +96,8 @@ CircuitPropagator::CircuitPropagator(const int num_nodes,
watch_index_to_arcs_.ResetFromFlatMapping(keys, values);
for (int node = 0; node < num_nodes_; ++node) {
if (assignment_.LiteralIsFalse(self_arcs_[node])) {
if (self_arcs_[node] == kFalseLiteralIndex ||
assignment_.LiteralIsFalse(Literal(self_arcs_[node]))) {
// For the multiple_subcircuit_through_zero case, must_be_in_cycle_ will
// be const and only contains zero.
if (node == 0 || !options_.multiple_subcircuit_through_zero) {
@@ -280,7 +280,7 @@ bool CircuitPropagator::Propagate() {
const int node = must_be_in_cycle_[i];
if (!in_current_path_[node]) {
miss_some_nodes = true;
extra_reason = self_arcs_[node].Index();
extra_reason = self_arcs_[node];
break;
}
}
@@ -320,7 +320,10 @@ bool CircuitPropagator::Propagate() {
BooleanVariable variable_with_same_reason = kNoBooleanVariable;
for (int node = 0; node < num_nodes_; ++node) {
if (in_current_path_[node]) continue;
if (assignment_.LiteralIsTrue(self_arcs_[node])) continue;
if (self_arcs_[node] >= 0 &&
assignment_.LiteralIsTrue(Literal(self_arcs_[node]))) {
continue;
}
// This shouldn't happen because ExactlyOnePerRowAndPerColumn() should
// have executed first and propagated self_arcs_[node] to false.
@@ -329,9 +332,12 @@ bool CircuitPropagator::Propagate() {
// We should have detected that above (miss_some_nodes == true). But we
// still need this for corner cases where the same literal is used for
// many arcs, and we just propagated it here.
if (assignment_.LiteralIsFalse(self_arcs_[node])) {
if (self_arcs_[node] == kFalseLiteralIndex ||
assignment_.LiteralIsFalse(Literal(self_arcs_[node]))) {
FillReasonForPath(start_node, trail_->MutableConflict());
trail_->MutableConflict()->push_back(self_arcs_[node]);
if (self_arcs_[node] != kFalseLiteralIndex) {
trail_->MutableConflict()->push_back(Literal(self_arcs_[node]));
}
return false;
}

View File

@@ -37,7 +37,7 @@ namespace sat {
//
// Nodes that are not in the unique allowed sub-circuit must point to themseves.
// A nodes that has no self-arc must thus be inside the sub-circuit. If there is
// no self-arc at all, then this constaint forces the circuit to go through all
// no self-arc at all, then this constraint forces the circuit to go through all
// the nodes. Multi-arcs are NOT supported.
//
// Important: for correctness, this constraint requires that "exactly one"
@@ -87,7 +87,7 @@ class CircuitPropagator : PropagatorInterface, ReversibleInterface {
//
// TODO(user): for large dense graph, using a matrix is faster and uses less
// memory. If the need arise we can have the two implementations.
std::vector<Literal> self_arcs_;
std::vector<LiteralIndex> self_arcs_;
absl::flat_hash_map<std::pair<int, int>, Literal> graph_;
// Data used to interpret the watch indices passed to IncrementalPropagate().

View File

@@ -176,6 +176,9 @@ void LoadVariables(const CpModelProto& model_proto,
// Compute the integer variable references used by the model.
absl::flat_hash_set<int> used_variables;
const bool some_linerization =
m->GetOrCreate<SatParameters>()->linearization_level() > 0;
IndexReferences refs;
for (int c = 0; c < model_proto.constraints_size(); ++c) {
const ConstraintProto& ct = model_proto.constraints(c);
@@ -183,6 +186,20 @@ void LoadVariables(const CpModelProto& model_proto,
for (const int ref : refs.variables) {
used_variables.insert(PositiveRef(ref));
}
// We always add a linear relaxation for circuit/route except for
// linearization level zero.
if (some_linerization) {
if (ct.constraint_case() == ConstraintProto::kCircuit) {
for (const int ref : ct.circuit().literals()) {
used_variables.insert(PositiveRef(ref));
}
} else if (ct.constraint_case() == ConstraintProto::kRoutes) {
for (const int ref : ct.routes().literals()) {
used_variables.insert(PositiveRef(ref));
}
}
}
}
// Add the objectives variables that needs to be referenceable as integer

View File

@@ -820,6 +820,10 @@ class FullProblemSolver : public SubSolver {
}
bool IsDone() override {
// On large problem, deletion can take a while, so is is better to do it
// while waiting for the slower worker to finish.
if (shared_->SearchIsDone()) return true;
return stop_at_first_solution_ &&
shared_->response->first_solution_solvers_should_stop()->load();
}
@@ -983,6 +987,8 @@ class FeasibilityPumpSolver : public SubSolver {
shared_->stat_tables.AddTimingStat(*this);
}
bool IsDone() override { return shared_->SearchIsDone(); }
bool TaskIsAvailable() override {
if (shared_->SearchIsDone()) return false;
absl::MutexLock mutex_lock(&mutex_);

View File

@@ -732,96 +732,111 @@ bool DisjunctiveSimplePrecedences::Propagate() {
return true;
}
bool DisjunctiveSimplePrecedences::Push(TaskTime before, int t) {
const int t_before = before.task_index;
DCHECK_NE(t_before, t);
helper_->ClearReason();
helper_->AddPresenceReason(t_before);
helper_->AddReasonForBeingBefore(t_before, t);
helper_->AddEndMinReason(t_before, before.time);
if (!helper_->IncreaseStartMin(t, before.time)) {
return false;
}
++stats_.num_propagations;
return true;
}
bool DisjunctiveSimplePrecedences::PropagateOneDirection() {
// We will loop in a decreasing way here.
// And add tasks that are present to the task_set_.
absl::Span<const TaskTime> task_by_decreasing_start_max =
helper_->TaskByDecreasingStartMax();
// We just keep amongst all the task before current_task, the one with the
// We just keep amongst all the task before current_end_min, the one with the
// highesh end-min.
TaskTime best_task_before = {-1, kMinIntegerValue};
to_propagate_.clear();
int blocking_task = -1;
processed_.assign(task_by_decreasing_start_max.size(), false);
for (const auto [current_task, current_end_min] :
helper_->TaskByIncreasingEndMin()) {
if (helper_->IsAbsent(current_task)) continue;
// We will loop in an increasing way here and consume task from beginning.
absl::Span<const TaskTime> task_by_increasing_end_min =
helper_->TaskByIncreasingEndMin();
for (; !task_by_increasing_end_min.empty();) {
// Skip absent task.
if (helper_->IsAbsent(task_by_increasing_end_min.front().task_index)) {
task_by_increasing_end_min.remove_prefix(1);
continue;
}
// Consider all task with a start_max < current_end_min.
int blocking_task = -1;
IntegerValue blocking_start_max;
IntegerValue current_end_min = task_by_increasing_end_min.front().time;
for (; true; task_by_decreasing_start_max.remove_suffix(1)) {
if (task_by_decreasing_start_max.empty()) {
// Small optim: this allows to process all remaining task rather than
// looping around are retesting this for all remaining tasks.
current_end_min = kMaxIntegerValue;
break;
}
for (; !task_by_decreasing_start_max.empty();
task_by_decreasing_start_max.remove_suffix(1)) {
const auto [t, start_max] = task_by_decreasing_start_max.back();
if (current_end_min <= start_max) break;
if (!helper_->IsPresent(t)) continue;
// If t has not been processed yet, it has a mandatory part, and we will
// delay all propagation until current_task is equal to this
// "blocking task".
// If t has a mandatory part, and extend further than current_end_min
// then we can push it first. All tasks for which their push is delayed
// are necessarily after this "blocking task".
//
// This idea is introduced in "Linear-Time Filtering Algorithms for the
// Disjunctive Constraints" Hamed Fahimi, Claude-Guy Quimper.
if (!processed_[t]) {
if (blocking_task != -1) {
// We have two blocking tasks, which means they are in conflict.
helper_->ClearReason();
helper_->AddPresenceReason(blocking_task);
helper_->AddPresenceReason(t);
helper_->AddReasonForBeingBefore(blocking_task, t);
helper_->AddReasonForBeingBefore(t, blocking_task);
return helper_->ReportConflict();
}
DCHECK_LT(start_max, helper_->ShiftedStartMin(t) + helper_->SizeMin(t))
<< " task should have mandatory part: "
<< helper_->TaskDebugString(t);
DCHECK(to_propagate_.empty());
const IntegerValue end_min = helper_->EndMin(t);
if (blocking_task == -1 && end_min >= current_end_min) {
DCHECK_LT(start_max, end_min) << " task should have mandatory part: "
<< helper_->TaskDebugString(t);
blocking_task = t;
to_propagate_.push_back(t);
} else {
const IntegerValue end_min = helper_->EndMin(t);
if (end_min > best_task_before.time) {
best_task_before = {t, end_min};
}
}
}
// If we have a blocking task, we delay the propagation until current_task
// is the blocking task.
if (blocking_task != current_task) {
to_propagate_.push_back(current_task);
if (blocking_task != -1) continue;
}
for (const int t : to_propagate_) {
DCHECK_NE(best_task_before.task_index, t);
DCHECK(!processed_[t]);
processed_[t] = true;
if (best_task_before.time > helper_->StartMin(t)) {
// Corner case if a previous push from to_propagate_ caused a subsequent
// task to be absent.
if (helper_->IsAbsent(t)) continue;
const int t_before = best_task_before.task_index;
blocking_start_max = start_max;
current_end_min = end_min;
} else if (blocking_task != -1 && blocking_start_max < end_min) {
// Conflict! the task is after the blocking_task but also before.
helper_->ClearReason();
helper_->AddPresenceReason(t_before);
helper_->AddReasonForBeingBefore(t_before, t);
helper_->AddEndMinReason(t_before, best_task_before.time);
if (!helper_->IncreaseStartMin(t, best_task_before.time)) {
return false;
}
++stats_.num_propagations;
}
if (t == blocking_task) {
blocking_task = -1;
const IntegerValue end_min = helper_->EndMin(t);
if (end_min > best_task_before.time) {
best_task_before = {t, end_min};
}
helper_->AddPresenceReason(blocking_task);
helper_->AddPresenceReason(t);
helper_->AddReasonForBeingBefore(blocking_task, t);
helper_->AddReasonForBeingBefore(t, blocking_task);
return helper_->ReportConflict();
} else if (end_min > best_task_before.time) {
best_task_before = {t, end_min};
}
}
// If we have a blocking task. We need to propagate it first.
if (blocking_task != -1) {
DCHECK(!helper_->IsAbsent(blocking_task));
if (best_task_before.time > helper_->StartMin(blocking_task)) {
if (!Push(best_task_before, blocking_task)) return false;
}
// Update best_task_before (it should likely be the blocking task).
const IntegerValue end_min = helper_->EndMin(blocking_task);
if (end_min > best_task_before.time) {
best_task_before = {blocking_task, end_min};
}
}
// Lets propagate all task after best_task_before.
for (; !task_by_increasing_end_min.empty();
task_by_increasing_end_min.remove_prefix(1)) {
const auto [t, end_min] = task_by_increasing_end_min.front();
if (end_min > current_end_min) break;
if (t == blocking_task) continue; // Already done.
// Lets propagate current_task.
if (best_task_before.time > helper_->StartMin(t)) {
// Corner case if a previous push caused a subsequent task to be absent.
if (helper_->IsAbsent(t)) continue;
if (!Push(best_task_before, t)) return false;
}
}
to_propagate_.clear();
}
return true;
}

View File

@@ -213,17 +213,13 @@ class DisjunctiveSimplePrecedences : public PropagatorInterface {
public:
explicit DisjunctiveSimplePrecedences(SchedulingConstraintHelper* helper,
Model* model = nullptr)
: helper_(helper), stats_("DisjunctiveSimplePrecedences", model) {
to_propagate_.ClearAndReserve(helper->NumTasks());
}
: helper_(helper), stats_("DisjunctiveSimplePrecedences", model) {}
bool Propagate() final;
int RegisterWith(GenericLiteralWatcher* watcher);
private:
bool PropagateOneDirection();
std::vector<bool> processed_;
FixedCapacityVector<int> to_propagate_;
bool Push(TaskTime before, int t);
SchedulingConstraintHelper* helper_;
PropagationStatistics stats_;

View File

@@ -500,11 +500,6 @@ void AppendCircuitRelaxation(const ConstraintProto& ct, Model* model,
const Literal arc = mapping->Literal(ct.circuit().literals(i));
const int tail = ct.circuit().tails(i);
const int head = ct.circuit().heads(i);
// Make sure this literal has a view.
if (!mapping->IsInteger(PositiveRef(ct.circuit().literals(i)))) {
CreateNewIntegerVariableFromLiteral(arc, model);
}
outgoing_arc_constraints[tail].push_back(arc);
incoming_arc_constraints[head].push_back(arc);
}
@@ -545,11 +540,6 @@ void AppendRoutesRelaxation(const ConstraintProto& ct, Model* model,
const Literal arc = mapping->Literal(ct.routes().literals(i));
const int tail = ct.routes().tails(i);
const int head = ct.routes().heads(i);
// Make sure this literal has a view.
if (!mapping->IsInteger(PositiveRef(ct.routes().literals(i)))) {
CreateNewIntegerVariableFromLiteral(arc, model);
}
outgoing_arc_constraints[tail].push_back(arc);
incoming_arc_constraints[head].push_back(arc);
}

View File

@@ -17,6 +17,7 @@
#include <algorithm>
#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>
@@ -41,6 +42,7 @@
#include "ortools/sat/sat_base.h"
#include "ortools/sat/sat_solver.h"
#include "ortools/sat/synchronization.h"
#include "ortools/sat/util.h"
#include "ortools/util/bitset.h"
#include "ortools/util/logging.h"
#include "ortools/util/strong_integers.h"
@@ -1117,13 +1119,7 @@ void GreaterThanAtLeastOneOfDetector::Add(Literal lit, LinearTerm a,
r.b.coeff = -r.b.coeff;
}
const int index = relations_.size();
relations_.push_back(std::move(r));
if (lit.Index() >= lit_to_relations_.size()) {
lit_to_relations_.resize(lit.Index() + 1);
}
lit_to_relations_[lit.Index()].push_back(index);
}
bool GreaterThanAtLeastOneOfDetector::AddRelationFromIndices(
@@ -1192,8 +1188,8 @@ int GreaterThanAtLeastOneOfDetector::
// Collect all relations impacted by this clause.
std::vector<std::pair<IntegerVariable, int>> infos;
for (const Literal l : clause) {
if (l.Index() >= lit_to_relations_.size()) continue;
for (const int index : lit_to_relations_[l.Index()]) {
if (l.Index() >= lit_to_relations_->size()) continue;
for (const int index : (*lit_to_relations_)[l.Index()]) {
const Relation& r = relations_[index];
if (r.a.var != kNoIntegerVariable && IntTypeAbs(r.a.coeff) == 1) {
infos.push_back({r.a.var, index});
@@ -1324,6 +1320,19 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints(
SOLVER_LOG(logger, "[Precedences] num_relations=", relations_.size(),
" num_clauses=", clauses->AllClausesInCreationOrder().size());
// Initialize lit_to_relations_.
{
std::vector<LiteralIndex> keys;
const int num_relations = relations_.size();
keys.reserve(num_relations);
for (int i = 0; i < num_relations; ++i) {
keys.push_back(relations_[i].enforcement.Index());
}
lit_to_relations_ =
std::make_unique<CompactVectorVector<LiteralIndex, int>>();
lit_to_relations_->ResetFromFlatMapping(keys, IdentityMap<int>());
}
// We have two possible approaches. For now, we prefer the first one except if
// there is too many clauses in the problem.
//
@@ -1374,9 +1383,8 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints(
}
// Release the memory, it is not longer needed.
lit_to_relations_.reset(nullptr);
gtl::STLClearObject(&relations_);
gtl::STLClearObject(&lit_to_relations_);
return num_added_constraints;
}

View File

@@ -514,9 +514,9 @@ class GreaterThanAtLeastOneOfDetector {
IntegerValue lhs;
IntegerValue rhs;
};
std::vector<Relation> relations_;
util_intops::StrongVector<LiteralIndex, std::vector<int>> lit_to_relations_;
std::unique_ptr<CompactVectorVector<LiteralIndex, int>> lit_to_relations_;
};
// =============================================================================

View File

@@ -2205,18 +2205,24 @@ class CpModel:
def add_modulo_equality(
self, target: LinearExprT, expr: LinearExprT, mod: LinearExprT
) -> Constraint:
"""Adds `target = expr % mod.
"""Adds `target = expr % mod`.
It uses the C convention, that is the result is the remainder of the
integral divisiion rounded towards 0.
integral division rounded towards 0.
For example:
* 10 % 3 = 1
* -10 % 3 = -1
* 10 % -3 = 1
* -10 % -3 = -1
Args:
target: the target expression.
expr: the expression to compute the modulo of.
mod: the modulus expression.
target: the target expression.
expr: the expression to compute the modulo of.
mod: the modulus expression.
Returns:
A `Constraint` object.
A `Constraint` object.
"""
ct = Constraint(self)
model_ct = self.__model.constraints[ct.index]