diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index b9b0fa7d2f..b561fb7cb1 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -1223,6 +1223,7 @@ cc_library( ":sat_base", ":sat_solver", ":synchronization", + ":util", "//ortools/base", "//ortools/base:stl_util", "//ortools/base:strong_vector", diff --git a/ortools/sat/circuit.cc b/ortools/sat/circuit.cc index 0aa0fe8544..74619e2048 100644 --- a/ortools/sat/circuit.cc +++ b/ortools/sat/circuit.cc @@ -56,8 +56,7 @@ CircuitPropagator::CircuitPropagator(const int num_nodes, values.reserve(num_arcs); graph_.reserve(num_arcs); - self_arcs_.resize(num_nodes_, - model->GetOrCreate()->GetFalseLiteral()); + self_arcs_.resize(num_nodes_, kFalseLiteralIndex); for (int arc = 0; arc < num_arcs; ++arc) { const int head = heads[arc]; const int tail = tails[arc]; @@ -65,7 +64,7 @@ CircuitPropagator::CircuitPropagator(const int num_nodes, if (assignment_.LiteralIsFalse(literal)) continue; if (tail == head) { - self_arcs_[tail] = literal; + self_arcs_[tail] = literal.Index(); } else { graph_[{tail, head}] = literal; } @@ -97,7 +96,8 @@ CircuitPropagator::CircuitPropagator(const int num_nodes, watch_index_to_arcs_.ResetFromFlatMapping(keys, values); for (int node = 0; node < num_nodes_; ++node) { - if (assignment_.LiteralIsFalse(self_arcs_[node])) { + if (self_arcs_[node] == kFalseLiteralIndex || + assignment_.LiteralIsFalse(Literal(self_arcs_[node]))) { // For the multiple_subcircuit_through_zero case, must_be_in_cycle_ will // be const and only contains zero. if (node == 0 || !options_.multiple_subcircuit_through_zero) { @@ -280,7 +280,7 @@ bool CircuitPropagator::Propagate() { const int node = must_be_in_cycle_[i]; if (!in_current_path_[node]) { miss_some_nodes = true; - extra_reason = self_arcs_[node].Index(); + extra_reason = self_arcs_[node]; break; } } @@ -320,7 +320,10 @@ bool CircuitPropagator::Propagate() { BooleanVariable variable_with_same_reason = kNoBooleanVariable; for (int node = 0; node < num_nodes_; ++node) { if (in_current_path_[node]) continue; - if (assignment_.LiteralIsTrue(self_arcs_[node])) continue; + if (self_arcs_[node] >= 0 && + assignment_.LiteralIsTrue(Literal(self_arcs_[node]))) { + continue; + } // This shouldn't happen because ExactlyOnePerRowAndPerColumn() should // have executed first and propagated self_arcs_[node] to false. @@ -329,9 +332,12 @@ bool CircuitPropagator::Propagate() { // We should have detected that above (miss_some_nodes == true). But we // still need this for corner cases where the same literal is used for // many arcs, and we just propagated it here. - if (assignment_.LiteralIsFalse(self_arcs_[node])) { + if (self_arcs_[node] == kFalseLiteralIndex || + assignment_.LiteralIsFalse(Literal(self_arcs_[node]))) { FillReasonForPath(start_node, trail_->MutableConflict()); - trail_->MutableConflict()->push_back(self_arcs_[node]); + if (self_arcs_[node] != kFalseLiteralIndex) { + trail_->MutableConflict()->push_back(Literal(self_arcs_[node])); + } return false; } diff --git a/ortools/sat/circuit.h b/ortools/sat/circuit.h index 27814484a4..b5f3c75cb2 100644 --- a/ortools/sat/circuit.h +++ b/ortools/sat/circuit.h @@ -37,7 +37,7 @@ namespace sat { // // Nodes that are not in the unique allowed sub-circuit must point to themseves. // A nodes that has no self-arc must thus be inside the sub-circuit. If there is -// no self-arc at all, then this constaint forces the circuit to go through all +// no self-arc at all, then this constraint forces the circuit to go through all // the nodes. Multi-arcs are NOT supported. // // Important: for correctness, this constraint requires that "exactly one" @@ -87,7 +87,7 @@ class CircuitPropagator : PropagatorInterface, ReversibleInterface { // // TODO(user): for large dense graph, using a matrix is faster and uses less // memory. If the need arise we can have the two implementations. - std::vector self_arcs_; + std::vector self_arcs_; absl::flat_hash_map, Literal> graph_; // Data used to interpret the watch indices passed to IncrementalPropagate(). diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index ae39a57830..f4d3a3728b 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -176,6 +176,9 @@ void LoadVariables(const CpModelProto& model_proto, // Compute the integer variable references used by the model. absl::flat_hash_set used_variables; + const bool some_linerization = + m->GetOrCreate()->linearization_level() > 0; + IndexReferences refs; for (int c = 0; c < model_proto.constraints_size(); ++c) { const ConstraintProto& ct = model_proto.constraints(c); @@ -183,6 +186,20 @@ void LoadVariables(const CpModelProto& model_proto, for (const int ref : refs.variables) { used_variables.insert(PositiveRef(ref)); } + + // We always add a linear relaxation for circuit/route except for + // linearization level zero. + if (some_linerization) { + if (ct.constraint_case() == ConstraintProto::kCircuit) { + for (const int ref : ct.circuit().literals()) { + used_variables.insert(PositiveRef(ref)); + } + } else if (ct.constraint_case() == ConstraintProto::kRoutes) { + for (const int ref : ct.routes().literals()) { + used_variables.insert(PositiveRef(ref)); + } + } + } } // Add the objectives variables that needs to be referenceable as integer diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index b959fe72f8..091309f299 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -820,6 +820,10 @@ class FullProblemSolver : public SubSolver { } bool IsDone() override { + // On large problem, deletion can take a while, so is is better to do it + // while waiting for the slower worker to finish. + if (shared_->SearchIsDone()) return true; + return stop_at_first_solution_ && shared_->response->first_solution_solvers_should_stop()->load(); } @@ -983,6 +987,8 @@ class FeasibilityPumpSolver : public SubSolver { shared_->stat_tables.AddTimingStat(*this); } + bool IsDone() override { return shared_->SearchIsDone(); } + bool TaskIsAvailable() override { if (shared_->SearchIsDone()) return false; absl::MutexLock mutex_lock(&mutex_); diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 777cc5803f..32c312af2b 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -732,96 +732,111 @@ bool DisjunctiveSimplePrecedences::Propagate() { return true; } +bool DisjunctiveSimplePrecedences::Push(TaskTime before, int t) { + const int t_before = before.task_index; + DCHECK_NE(t_before, t); + helper_->ClearReason(); + helper_->AddPresenceReason(t_before); + helper_->AddReasonForBeingBefore(t_before, t); + helper_->AddEndMinReason(t_before, before.time); + if (!helper_->IncreaseStartMin(t, before.time)) { + return false; + } + ++stats_.num_propagations; + return true; +} + bool DisjunctiveSimplePrecedences::PropagateOneDirection() { // We will loop in a decreasing way here. // And add tasks that are present to the task_set_. absl::Span task_by_decreasing_start_max = helper_->TaskByDecreasingStartMax(); - // We just keep amongst all the task before current_task, the one with the + // We just keep amongst all the task before current_end_min, the one with the // highesh end-min. TaskTime best_task_before = {-1, kMinIntegerValue}; - to_propagate_.clear(); - int blocking_task = -1; - processed_.assign(task_by_decreasing_start_max.size(), false); - for (const auto [current_task, current_end_min] : - helper_->TaskByIncreasingEndMin()) { - if (helper_->IsAbsent(current_task)) continue; + // We will loop in an increasing way here and consume task from beginning. + absl::Span task_by_increasing_end_min = + helper_->TaskByIncreasingEndMin(); + + for (; !task_by_increasing_end_min.empty();) { + // Skip absent task. + if (helper_->IsAbsent(task_by_increasing_end_min.front().task_index)) { + task_by_increasing_end_min.remove_prefix(1); + continue; + } + + // Consider all task with a start_max < current_end_min. + int blocking_task = -1; + IntegerValue blocking_start_max; + IntegerValue current_end_min = task_by_increasing_end_min.front().time; + for (; true; task_by_decreasing_start_max.remove_suffix(1)) { + if (task_by_decreasing_start_max.empty()) { + // Small optim: this allows to process all remaining task rather than + // looping around are retesting this for all remaining tasks. + current_end_min = kMaxIntegerValue; + break; + } - for (; !task_by_decreasing_start_max.empty(); - task_by_decreasing_start_max.remove_suffix(1)) { const auto [t, start_max] = task_by_decreasing_start_max.back(); if (current_end_min <= start_max) break; if (!helper_->IsPresent(t)) continue; - // If t has not been processed yet, it has a mandatory part, and we will - // delay all propagation until current_task is equal to this - // "blocking task". + // If t has a mandatory part, and extend further than current_end_min + // then we can push it first. All tasks for which their push is delayed + // are necessarily after this "blocking task". // // This idea is introduced in "Linear-Time Filtering Algorithms for the // Disjunctive Constraints" Hamed Fahimi, Claude-Guy Quimper. - if (!processed_[t]) { - if (blocking_task != -1) { - // We have two blocking tasks, which means they are in conflict. - helper_->ClearReason(); - helper_->AddPresenceReason(blocking_task); - helper_->AddPresenceReason(t); - helper_->AddReasonForBeingBefore(blocking_task, t); - helper_->AddReasonForBeingBefore(t, blocking_task); - return helper_->ReportConflict(); - } - DCHECK_LT(start_max, helper_->ShiftedStartMin(t) + helper_->SizeMin(t)) - << " task should have mandatory part: " - << helper_->TaskDebugString(t); - DCHECK(to_propagate_.empty()); + const IntegerValue end_min = helper_->EndMin(t); + if (blocking_task == -1 && end_min >= current_end_min) { + DCHECK_LT(start_max, end_min) << " task should have mandatory part: " + << helper_->TaskDebugString(t); blocking_task = t; - to_propagate_.push_back(t); - } else { - const IntegerValue end_min = helper_->EndMin(t); - if (end_min > best_task_before.time) { - best_task_before = {t, end_min}; - } - } - } - - // If we have a blocking task, we delay the propagation until current_task - // is the blocking task. - if (blocking_task != current_task) { - to_propagate_.push_back(current_task); - if (blocking_task != -1) continue; - } - - for (const int t : to_propagate_) { - DCHECK_NE(best_task_before.task_index, t); - DCHECK(!processed_[t]); - processed_[t] = true; - - if (best_task_before.time > helper_->StartMin(t)) { - // Corner case if a previous push from to_propagate_ caused a subsequent - // task to be absent. - if (helper_->IsAbsent(t)) continue; - - const int t_before = best_task_before.task_index; + blocking_start_max = start_max; + current_end_min = end_min; + } else if (blocking_task != -1 && blocking_start_max < end_min) { + // Conflict! the task is after the blocking_task but also before. helper_->ClearReason(); - helper_->AddPresenceReason(t_before); - helper_->AddReasonForBeingBefore(t_before, t); - helper_->AddEndMinReason(t_before, best_task_before.time); - if (!helper_->IncreaseStartMin(t, best_task_before.time)) { - return false; - } - ++stats_.num_propagations; - } - - if (t == blocking_task) { - blocking_task = -1; - const IntegerValue end_min = helper_->EndMin(t); - if (end_min > best_task_before.time) { - best_task_before = {t, end_min}; - } + helper_->AddPresenceReason(blocking_task); + helper_->AddPresenceReason(t); + helper_->AddReasonForBeingBefore(blocking_task, t); + helper_->AddReasonForBeingBefore(t, blocking_task); + return helper_->ReportConflict(); + } else if (end_min > best_task_before.time) { + best_task_before = {t, end_min}; + } + } + + // If we have a blocking task. We need to propagate it first. + if (blocking_task != -1) { + DCHECK(!helper_->IsAbsent(blocking_task)); + if (best_task_before.time > helper_->StartMin(blocking_task)) { + if (!Push(best_task_before, blocking_task)) return false; + } + + // Update best_task_before (it should likely be the blocking task). + const IntegerValue end_min = helper_->EndMin(blocking_task); + if (end_min > best_task_before.time) { + best_task_before = {blocking_task, end_min}; + } + } + + // Lets propagate all task after best_task_before. + for (; !task_by_increasing_end_min.empty(); + task_by_increasing_end_min.remove_prefix(1)) { + const auto [t, end_min] = task_by_increasing_end_min.front(); + if (end_min > current_end_min) break; + if (t == blocking_task) continue; // Already done. + + // Lets propagate current_task. + if (best_task_before.time > helper_->StartMin(t)) { + // Corner case if a previous push caused a subsequent task to be absent. + if (helper_->IsAbsent(t)) continue; + if (!Push(best_task_before, t)) return false; } } - to_propagate_.clear(); } return true; } diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index 7c78cf2b16..f27232866f 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -213,17 +213,13 @@ class DisjunctiveSimplePrecedences : public PropagatorInterface { public: explicit DisjunctiveSimplePrecedences(SchedulingConstraintHelper* helper, Model* model = nullptr) - : helper_(helper), stats_("DisjunctiveSimplePrecedences", model) { - to_propagate_.ClearAndReserve(helper->NumTasks()); - } + : helper_(helper), stats_("DisjunctiveSimplePrecedences", model) {} bool Propagate() final; int RegisterWith(GenericLiteralWatcher* watcher); private: bool PropagateOneDirection(); - - std::vector processed_; - FixedCapacityVector to_propagate_; + bool Push(TaskTime before, int t); SchedulingConstraintHelper* helper_; PropagationStatistics stats_; diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 7d9da1d2ce..5175b818c7 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -500,11 +500,6 @@ void AppendCircuitRelaxation(const ConstraintProto& ct, Model* model, const Literal arc = mapping->Literal(ct.circuit().literals(i)); const int tail = ct.circuit().tails(i); const int head = ct.circuit().heads(i); - - // Make sure this literal has a view. - if (!mapping->IsInteger(PositiveRef(ct.circuit().literals(i)))) { - CreateNewIntegerVariableFromLiteral(arc, model); - } outgoing_arc_constraints[tail].push_back(arc); incoming_arc_constraints[head].push_back(arc); } @@ -545,11 +540,6 @@ void AppendRoutesRelaxation(const ConstraintProto& ct, Model* model, const Literal arc = mapping->Literal(ct.routes().literals(i)); const int tail = ct.routes().tails(i); const int head = ct.routes().heads(i); - - // Make sure this literal has a view. - if (!mapping->IsInteger(PositiveRef(ct.routes().literals(i)))) { - CreateNewIntegerVariableFromLiteral(arc, model); - } outgoing_arc_constraints[tail].push_back(arc); incoming_arc_constraints[head].push_back(arc); } diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index 2fc99ce49b..34b1559a0d 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" #include "ortools/util/bitset.h" #include "ortools/util/logging.h" #include "ortools/util/strong_integers.h" @@ -1117,13 +1119,7 @@ void GreaterThanAtLeastOneOfDetector::Add(Literal lit, LinearTerm a, r.b.coeff = -r.b.coeff; } - const int index = relations_.size(); relations_.push_back(std::move(r)); - - if (lit.Index() >= lit_to_relations_.size()) { - lit_to_relations_.resize(lit.Index() + 1); - } - lit_to_relations_[lit.Index()].push_back(index); } bool GreaterThanAtLeastOneOfDetector::AddRelationFromIndices( @@ -1192,8 +1188,8 @@ int GreaterThanAtLeastOneOfDetector:: // Collect all relations impacted by this clause. std::vector> infos; for (const Literal l : clause) { - if (l.Index() >= lit_to_relations_.size()) continue; - for (const int index : lit_to_relations_[l.Index()]) { + if (l.Index() >= lit_to_relations_->size()) continue; + for (const int index : (*lit_to_relations_)[l.Index()]) { const Relation& r = relations_[index]; if (r.a.var != kNoIntegerVariable && IntTypeAbs(r.a.coeff) == 1) { infos.push_back({r.a.var, index}); @@ -1324,6 +1320,19 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints( SOLVER_LOG(logger, "[Precedences] num_relations=", relations_.size(), " num_clauses=", clauses->AllClausesInCreationOrder().size()); + // Initialize lit_to_relations_. + { + std::vector keys; + const int num_relations = relations_.size(); + keys.reserve(num_relations); + for (int i = 0; i < num_relations; ++i) { + keys.push_back(relations_[i].enforcement.Index()); + } + lit_to_relations_ = + std::make_unique>(); + lit_to_relations_->ResetFromFlatMapping(keys, IdentityMap()); + } + // We have two possible approaches. For now, we prefer the first one except if // there is too many clauses in the problem. // @@ -1374,9 +1383,8 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints( } // Release the memory, it is not longer needed. + lit_to_relations_.reset(nullptr); gtl::STLClearObject(&relations_); - gtl::STLClearObject(&lit_to_relations_); - return num_added_constraints; } diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 2a467e77a9..2e984f41cf 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -514,9 +514,9 @@ class GreaterThanAtLeastOneOfDetector { IntegerValue lhs; IntegerValue rhs; }; - std::vector relations_; - util_intops::StrongVector> lit_to_relations_; + + std::unique_ptr> lit_to_relations_; }; // ============================================================================= diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 27513448c8..9cf9443f96 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -2205,18 +2205,24 @@ class CpModel: def add_modulo_equality( self, target: LinearExprT, expr: LinearExprT, mod: LinearExprT ) -> Constraint: - """Adds `target = expr % mod. + """Adds `target = expr % mod`. It uses the C convention, that is the result is the remainder of the - integral divisiion rounded towards 0. - + integral division rounded towards 0. + + For example: + * 10 % 3 = 1 + * -10 % 3 = -1 + * 10 % -3 = 1 + * -10 % -3 = -1 + Args: - target: the target expression. - expr: the expression to compute the modulo of. - mod: the modulus expression. - + target: the target expression. + expr: the expression to compute the modulo of. + mod: the modulus expression. + Returns: - A `Constraint` object. + A `Constraint` object. """ ct = Constraint(self) model_ct = self.__model.constraints[ct.index]