23#include "absl/container/btree_map.h"
24#include "absl/container/flat_hash_map.h"
25#include "absl/container/flat_hash_set.h"
26#include "absl/meta/type_traits.h"
27#include "absl/strings/str_cat.h"
28#include "absl/types/span.h"
33#include "ortools/sat/cp_model.pb.h"
36#include "ortools/sat/sat_parameters.pb.h"
// Expands the reservoir constraint held by `ct` into linear constraints added
// to context->working_model.
// NOTE(review): this listing has gaps (the embedded source line numbers
// skip), so the comments below only describe statements that are visible.
46void ExpandReservoir(ConstraintProto*
ct, PresolveContext*
context) {
// An inverted level range [min_level, max_level] is reported as an
// unsatisfiable model.
47 if (
ct->reservoir().min_level() >
ct->reservoir().max_level()) {
48 VLOG(1) <<
"Empty level domain in reservoir constraint.";
49 return (
void)
context->NotifyThatModelIsUnsat();
52 const ReservoirConstraintProto& reservoir =
ct->reservoir();
53 const int num_events = reservoir.time_exprs_size();
55 const int true_literal =
context->GetOrCreateConstantVar(1);
// Returns the activity literal of event `index`; when the proto carries no
// active literals, every event maps to the constant variable created above.
57 const auto is_active_literal = [&reservoir, true_literal](
int index) {
58 if (reservoir.active_literals_size() == 0)
return true_literal;
59 return reservoir.active_literals(
index);
// Counters over the level changes; the loop body that updates them is not
// visible here (presumably it classifies each demand by sign).
62 int num_positives = 0;
63 int num_negatives = 0;
64 for (
const int64_t
demand : reservoir.level_changes()) {
// Maps an ordered pair of event indices {a, b} to the literal created for
// "time_a <= time_b".
72 absl::flat_hash_map<std::pair<int, int>,
int> precedence_cache;
// Mixed-sign case: both positive and negative level changes are present.
74 if (num_positives > 0 && num_negatives > 0) {
// For every unordered pair (i, j) whose activity literals are not false,
// create (or fetch) the reified precedence literal in both directions, name
// the corresponding variable, and cache it.
76 for (
int i = 0; i < num_events - 1; ++i) {
77 const int active_i = is_active_literal(i);
78 if (
context->LiteralIsFalse(active_i))
continue;
79 const LinearExpressionProto& time_i = reservoir.time_exprs(i);
81 for (
int j = i + 1; j < num_events; ++j) {
82 const int active_j = is_active_literal(j);
83 if (
context->LiteralIsFalse(active_j))
continue;
84 const LinearExpressionProto& time_j = reservoir.time_exprs(j);
86 const int i_lesseq_j =
context->GetOrCreateReifiedPrecedenceLiteral(
87 time_i, time_j, active_i, active_j);
88 context->working_model->mutable_variables(i_lesseq_j)
89 ->set_name(absl::StrCat(i,
" before ", j));
90 precedence_cache[{i, j}] = i_lesseq_j;
91 const int j_lesseq_i =
context->GetOrCreateReifiedPrecedenceLiteral(
92 time_j, time_i, active_j, active_i);
93 context->working_model->mutable_variables(j_lesseq_i)
94 ->set_name(absl::StrCat(j,
" before ", i));
95 precedence_cache[{j, i}] = j_lesseq_i;
// One linear "level" constraint per event i, enforced by its activity
// literal, built from the precedence literals of the other events.
103 for (
int i = 0; i < num_events; ++i) {
104 const int active_i = is_active_literal(i);
105 if (
context->LiteralIsFalse(active_i))
continue;
108 ConstraintProto*
const level =
context->working_model->add_constraints();
109 level->add_enforcement_literal(active_i);
113 for (
int j = 0; j < num_events; ++j) {
114 if (i == j)
continue;
115 const int active_j = is_active_literal(j);
116 if (
context->LiteralIsFalse(active_j))
continue;
// The literal "j before i" must have been cached by the pair loop above.
118 const auto prec_it = precedence_cache.find({j, i});
119 CHECK(prec_it != precedence_cache.end());
120 const int prec_lit = prec_it->second;
121 const int64_t
demand = reservoir.level_changes(j);
// NOTE(review): the two add_vars/add_coeffs pairs below (+demand, then
// -demand) are presumably alternative branches; the selecting condition and
// the definition of `offset` sit on lines that are not visible here.
123 level->mutable_linear()->add_vars(prec_lit);
124 level->mutable_linear()->add_coeffs(
demand);
126 level->mutable_linear()->add_vars(prec_lit);
127 level->mutable_linear()->add_coeffs(-
demand);
// Domain of the level constraint: [min_level - demand_i + offset,
// max_level - demand_i + offset], computed through CapSub/CapAdd.
133 const int64_t demand_i = reservoir.level_changes(i);
134 level->mutable_linear()->add_domain(
135 CapAdd(
CapSub(reservoir.min_level(), demand_i), offset));
136 level->mutable_linear()->add_domain(
137 CapAdd(
CapSub(reservoir.max_level(), demand_i), offset));
// A single linear constraint: the sum of level changes weighted by the
// activity literals must lie in [min_level, max_level]. The declaration of
// `sum` and the branch this belongs to are not visible here.
143 context->working_model->add_constraints()->mutable_linear();
144 for (
int i = 0; i < num_events; ++i) {
145 sum->add_vars(is_active_literal(i));
146 sum->add_coeffs(reservoir.level_changes(i));
148 sum->add_domain(reservoir.min_level());
149 sum->add_domain(reservoir.max_level());
153 context->UpdateRuleStats(
"reservoir: expanded");
// Expands the int_mod constraint of `ct` into an int_div, an int_prod and a
// linear constraint over two new integer variables.
// NOTE(review): this listing has gaps; only visible statements are described.
156void ExpandIntMod(ConstraintProto*
ct, PresolveContext*
context) {
157 const LinearArgumentProto& int_mod =
ct->int_mod();
// exprs(1) is the modulus; nothing is expanded here when it is fixed.
158 const LinearExpressionProto& mod_expr = int_mod.exprs(1);
159 if (
context->IsFixed(mod_expr))
return;
161 const LinearExpressionProto& expr = int_mod.exprs(0);
162 const LinearExpressionProto& target_expr = int_mod.target();
// Tighten the target domain with the modulo of the two domain supersets.
// What happens when the intersection fails is on lines not visible here.
165 if (!
context->IntersectDomainWith(
166 target_expr,
context->DomainSuperSetOf(expr).PositiveModuloBySuperset(
167 context->DomainSuperSetOf(mod_expr)))) {
// Helper: appends a constraint to the working model that carries the same
// enforcement literals as `ct`.
172 auto new_enforced_constraint = [&]() {
173 ConstraintProto* new_ct =
context->working_model->add_constraints();
174 *new_ct->mutable_enforcement_literal() =
ct->enforcement_literal();
// New variable for the quotient, constrained by: div_expr = expr / mod_expr.
179 const int div_var =
context->NewIntVar(
180 context->DomainSuperSetOf(expr).PositiveDivisionBySuperset(
181 context->DomainSuperSetOf(mod_expr)));
182 LinearExpressionProto div_expr;
183 div_expr.add_vars(div_var);
184 div_expr.add_coeffs(1);
186 LinearArgumentProto*
const div_proto =
187 new_enforced_constraint()->mutable_int_div();
188 *div_proto->mutable_target() = div_expr;
189 *div_proto->add_exprs() = expr;
190 *div_proto->add_exprs() = mod_expr;
// New variable for the product div_expr * mod_expr. Its domain is a
// multiplication superset intersected with domain(expr) - domain(target);
// the first operand of the multiplication is on a line not visible here.
193 const Domain prod_domain =
195 .ContinuousMultiplicationBy(
context->DomainSuperSetOf(mod_expr))
196 .IntersectionWith(
context->DomainSuperSetOf(expr).AdditionWith(
197 context->DomainSuperSetOf(target_expr).Negation()));
198 const int prod_var =
context->NewIntVar(prod_domain);
199 LinearExpressionProto prod_expr;
200 prod_expr.add_vars(prod_var);
201 prod_expr.add_coeffs(1);
203 LinearArgumentProto*
const int_prod =
204 new_enforced_constraint()->mutable_int_prod();
205 *int_prod->mutable_target() = prod_expr;
206 *int_prod->add_exprs() = div_expr;
207 *int_prod->add_exprs() = mod_expr;
// Linear constraint tying expr, prod_expr and target_expr together; its
// terms are filled on lines not visible here.
210 LinearConstraintProto*
const lin =
211 new_enforced_constraint()->mutable_linear();
219 context->UpdateRuleStats(
"int_mod: expanded");
// Encodes a product involving the Boolean `bool_ref` with two enforced linear
// constraints: one active when bool_ref is true, one when it is false.
// NOTE(review): the remaining parameters and the calls that fill the linear
// terms are on lines not visible here; only the domains [0, 0] are shown.
223void ExpandIntProdWithBoolean(
int bool_ref,
224 const LinearExpressionProto& int_expr,
225 const LinearExpressionProto& product_expr,
// Constraint enforced by bool_ref, with domain {0}.
227 ConstraintProto*
const one =
context->working_model->add_constraints();
228 one->add_enforcement_literal(bool_ref);
229 one->mutable_linear()->add_domain(0);
230 one->mutable_linear()->add_domain(0);
233 one->mutable_linear());
// Constraint enforced by the negation of bool_ref, with domain {0}.
235 ConstraintProto*
const zero =
context->working_model->add_constraints();
236 zero->add_enforcement_literal(
NegatedRef(bool_ref));
237 zero->mutable_linear()->add_domain(0);
238 zero->mutable_linear()->add_domain(0);
240 zero->mutable_linear());
// Expands an int_prod constraint with exactly two factors when one of the
// factors is a literal.
// NOTE(review): the definitions of a_is_literal / b_is_literal and the calls
// made in each branch are on lines not visible here.
243void ExpandIntProd(ConstraintProto*
ct, PresolveContext*
context) {
244 const LinearArgumentProto& int_prod =
ct->int_prod();
// Only products of exactly two expressions are handled.
245 if (int_prod.exprs_size() != 2)
return;
246 const LinearExpressionProto&
a = int_prod.exprs(0);
247 const LinearExpressionProto&
b = int_prod.exprs(1);
248 const LinearExpressionProto& p = int_prod.target();
// Both branches record the same rule statistic.
255 if (a_is_literal && !b_is_literal) {
258 context->UpdateRuleStats(
"int_prod: expanded product with Boolean var");
259 }
else if (b_is_literal) {
262 context->UpdateRuleStats(
"int_prod: expanded product with Boolean var");
// Expands an inverse constraint (f_direct / f_inverse) by filtering the
// variable domains and then sharing value-encoding literals between the two
// arrays.
// NOTE(review): this listing has gaps; only visible statements are described.
266void ExpandInverse(ConstraintProto*
ct, PresolveContext*
context) {
267 const auto& f_direct =
ct->inverse().f_direct();
268 const auto& f_inverse =
ct->inverse().f_inverse();
269 const int n = f_direct.size();
// Every variable of both arrays is restricted to [0, n - 1]. How
// used_variables is filled is on lines not visible here.
278 absl::flat_hash_set<int> used_variables;
279 for (
const int ref : f_direct) {
281 if (!
context->IntersectDomainWith(ref, Domain(0, n - 1))) {
282 VLOG(1) <<
"Empty domain for a variable in ExpandInverse()";
286 for (
const int ref : f_inverse) {
288 if (!
context->IntersectDomainWith(ref, Domain(0, n - 1))) {
289 VLOG(1) <<
"Empty domain for a variable in ExpandInverse()";
// Taken when used_variables does not hold 2 * n distinct entries. The
// domain reduction applied for each pair i != j is not visible here.
296 if (used_variables.size() != 2 * n) {
297 for (
int i = 0; i < n; ++i) {
298 for (
int j = 0; j < n; ++j) {
303 if (i == j)
continue;
304 if (!
context->IntersectDomainWith(
// Scratch buffer reused by the lambda below.
314 std::vector<int64_t> possible_values;
// For each i, collects the values j of direct[i] such that the domain of
// inverse[j] contains i, flags when some value was dropped, and intersects
// a domain (arguments not visible here).
317 const auto filter_inverse_domain =
318 [
context, n, &possible_values](
const auto& direct,
const auto& inverse) {
320 for (
int i = 0; i < n; ++i) {
321 possible_values.clear();
322 const Domain domain =
context->DomainOf(direct[i]);
323 bool removed_value =
false;
324 for (
const int64_t j : domain.Values()) {
325 if (
context->DomainOf(inverse[j]).Contains(i)) {
326 possible_values.push_back(j);
328 removed_value =
true;
332 if (!
context->IntersectDomainWith(
334 VLOG(1) <<
"Empty domain for a variable in ExpandInverse()";
// Filter in both directions; stop as soon as one call returns false.
344 if (!filter_inverse_domain(f_direct, f_inverse))
return;
345 if (!filter_inverse_domain(f_inverse, f_direct))
return;
// Link the encodings: the literal for (f_direct[i] == j) and the literal
// for (f_inverse[j] == i) are made to be the same literal.
351 for (
int i = 0; i < n; ++i) {
352 const int f_i = f_direct[i];
353 for (
const int64_t j :
context->DomainOf(f_i).Values()) {
355 const int r_j = f_inverse[j];
// Reuse an existing literal for (r_j == i) when there is one; the
// statements at 360-361 are presumably the else branch that creates the
// literal on the f_i side instead.
357 if (
context->HasVarValueEncoding(r_j, i, &r_j_i)) {
358 context->InsertVarValueEncoding(r_j_i, f_i, j);
360 const int f_i_j =
context->GetOrCreateVarValueEncoding(f_i, j);
361 context->InsertVarValueEncoding(f_i_j, r_j, i);
367 context->UpdateRuleStats(
"inverse: expanded");
// Handles the element constraint special case where the index and the target
// are the same reference (checked by the DCHECK below).
// NOTE(review): this listing has gaps; only visible statements are described.
371void ExpandElementWithTargetEqualIndex(ConstraintProto*
ct,
373 const ElementConstraintProto& element =
ct->element();
374 DCHECK_EQ(element.index(), element.target());
376 const int index_ref = element.index();
// Keep only the index values v for which the domain of vars(v) contains v.
377 std::vector<int64_t> valid_indices;
378 for (
const int64_t v :
context->DomainOf(index_ref).Values()) {
379 if (!
context->DomainContains(element.vars(v), v))
continue;
380 valid_indices.push_back(v);
// Shrink the index domain when some values were rejected.
382 if (valid_indices.size() <
context->DomainOf(index_ref).Size()) {
383 if (!
context->IntersectDomainWith(index_ref,
385 VLOG(1) <<
"No compatible variable domains in "
386 "ExpandElementWithTargetEqualIndex()";
389 context->UpdateRuleStats(
"element: reduced index domain");
// For every remaining v, a call (its name is on a line not visible here)
// receives the literal (index == v), the variable vars(v) and Domain(v).
392 for (
const int64_t v :
context->DomainOf(index_ref).Values()) {
393 const int var = element.vars(v);
396 context->GetOrCreateVarValueEncoding(index_ref, v),
var, Domain(v));
399 "element: expanded with special case target = index");
// Expands an element constraint using value-encoding literals of the index
// and of the target.
// NOTE(review): the definition of `value` inside both loops, and of
// `exactly_one`, are on lines not visible here (presumably `value` is the
// fixed value of vars(v) - confirm against the full source).
404void ExpandConstantArrayElement(ConstraintProto*
ct, PresolveContext*
context) {
405 const ElementConstraintProto& element =
ct->element();
406 const int index_ref = element.index();
407 const int target_ref = element.target();
410 const Domain index_domain =
context->DomainOf(index_ref);
411 const Domain target_domain =
context->DomainOf(target_ref);
// Support clause per value, created only for values seen at least twice.
418 absl::flat_hash_map<int64_t, BoolArgumentProto*> supports;
420 absl::flat_hash_map<int64_t, int> constant_var_values_usage;
421 for (
const int64_t v : index_domain.Values()) {
// On the second occurrence of a value, start a bool_or that begins with
// the negated (target == value) literal.
424 if (++constant_var_values_usage[
value] == 2) {
426 BoolArgumentProto*
const support =
427 context->working_model->add_constraints()->mutable_bool_or();
428 const int target_literal =
429 context->GetOrCreateVarValueEncoding(target_ref,
value);
430 support->add_literals(
NegatedRef(target_literal));
431 supports[
value] = support;
// Exactly one (index == v) literal is true.
440 context->working_model->add_constraints()->mutable_exactly_one();
441 for (
const int64_t v : index_domain.Values()) {
442 const int index_literal =
443 context->GetOrCreateVarValueEncoding(index_ref, v);
444 exactly_one->add_literals(index_literal);
// Value with a support clause: index_literal implies the target literal and
// is appended to that clause.
447 const auto& it = supports.find(
value);
448 if (it != supports.end()) {
451 const int target_literal =
452 context->GetOrCreateVarValueEncoding(target_ref,
value);
453 context->AddImplication(index_literal, target_literal);
454 it->second->add_literals(index_literal);
// Presumably the else branch: index_literal itself is registered as the
// encoding of (target == value).
457 context->InsertVarValueEncoding(index_literal, target_ref,
value);
462 context->UpdateRuleStats(
"element: expanded value element");
// General expansion of an element constraint: one (index == v) literal per
// index value, each one linking the target to vars(v).
// NOTE(review): this listing has gaps; only visible statements are described.
467void ExpandVariableElement(ConstraintProto*
ct, PresolveContext*
context) {
468 const ElementConstraintProto& element =
ct->element();
469 const int index_ref = element.index();
470 const int target_ref = element.target();
471 const Domain index_domain =
context->DomainOf(index_ref);
// Clause stating that at least one index literal is true.
473 BoolArgumentProto* bool_or =
474 context->working_model->add_constraints()->mutable_bool_or();
476 for (
const int64_t v : index_domain.Values()) {
477 const int var = element.vars(v);
478 const Domain var_domain =
context->DomainOf(
var);
479 const int index_lit =
context->GetOrCreateVarValueEncoding(index_ref, v);
480 bool_or->add_literals(index_lit);
// Fixed variable: index_lit implies that the target lies in var_domain.
482 if (var_domain.IsFixed()) {
483 context->AddImplyInDomain(index_lit, target_ref, var_domain);
// Presumably the else branch: index_lit enforces var - target == 0.
// NOTE(review): this local `ct` shadows the function parameter `ct`.
485 ConstraintProto*
const ct =
context->working_model->add_constraints();
486 ct->add_enforcement_literal(index_lit);
487 ct->mutable_linear()->add_vars(
var);
488 ct->mutable_linear()->add_coeffs(1);
489 ct->mutable_linear()->add_vars(target_ref);
490 ct->mutable_linear()->add_coeffs(-1);
491 ct->mutable_linear()->add_domain(0);
492 ct->mutable_linear()->add_domain(0);
496 context->UpdateRuleStats(
"element: expanded");
// Entry point for element constraints: reduces the index and target domains,
// then dispatches to a specialized expansion.
// NOTE(review): the final dispatch and the failure branches are on lines not
// visible here.
500void ExpandElement(ConstraintProto*
ct, PresolveContext*
context) {
501 const ElementConstraintProto& element =
ct->element();
503 const int index_ref = element.index();
504 const int target_ref = element.target();
505 const int size = element.vars_size();
// The index must address an existing entry of vars.
509 if (!
context->IntersectDomainWith(index_ref, Domain(0, size - 1))) {
510 VLOG(1) <<
"Empty domain for the index variable in ExpandElement()";
// Special case: index and target are the same reference.
515 if (index_ref == target_ref) {
516 ExpandElementWithTargetEqualIndex(
ct,
context);
// Scan the index values: keep those whose variable domain intersects the
// target domain, accumulate the union of the kept domains, and record
// whether every kept variable is fixed.
521 bool all_constants =
true;
522 std::vector<int64_t> valid_indices;
523 const Domain index_domain =
context->DomainOf(index_ref);
524 const Domain target_domain =
context->DomainOf(target_ref);
525 Domain reached_domain;
526 for (
const int64_t v : index_domain.Values()) {
527 const Domain var_domain =
context->DomainOf(element.vars(v));
528 if (var_domain.IntersectionWith(target_domain).IsEmpty())
continue;
530 valid_indices.push_back(v);
531 reached_domain = reached_domain.UnionWith(var_domain);
532 if (var_domain.Min() != var_domain.Max()) {
533 all_constants =
false;
// Shrink the index domain when some index values were rejected.
537 if (valid_indices.size() < index_domain.Size()) {
538 if (!
context->IntersectDomainWith(index_ref,
540 VLOG(1) <<
"No compatible variable domains in ExpandElement()";
544 context->UpdateRuleStats(
"element: reduced index domain");
// Restrict the target to the values reachable through a kept variable.
549 bool target_domain_changed =
false;
550 if (!
context->IntersectDomainWith(target_ref, reached_domain,
551 &target_domain_changed)) {
555 if (target_domain_changed) {
556 context->UpdateRuleStats(
"element: reduced target domain");
// Links each literal in `literals` to the encoding literal of its value:
// literals[i] is associated with encoding.at(values[i]). The two vectors
// must have the same size (CHECKed).
// NOTE(review): the last parameter (presumably the PresolveContext named
// `context`) is on a line not visible here.
569void LinkLiteralsAndValues(
const std::vector<int>& literals,
570 const std::vector<int64_t>& values,
571 const absl::flat_hash_map<int64_t, int>& encoding,
573 CHECK_EQ(literals.size(), values.size());
// Group the literals by the encoding literal of their value. A btree_map
// is used, so the loop at 590 visits keys in sorted order.
579 absl::btree_map<int, std::vector<int>> encoding_lit_to_support;
584 for (
int i = 0; i < values.size(); ++i) {
585 encoding_lit_to_support[encoding.at(values[i])].push_back(literals[i]);
590 for (
const auto& [encoding_lit, support] : encoding_lit_to_support) {
591 CHECK(!support.empty());
// A single supporting literal is stored as equal to the encoding literal.
592 if (support.size() == 1) {
593 context->StoreBooleanEqualityRelation(encoding_lit, support[0]);
// Otherwise (presumably the else branch): encoding_lit implies the
// disjunction of its supports, and every support implies encoding_lit.
595 BoolArgumentProto* bool_or =
596 context->working_model->add_constraints()->mutable_bool_or();
597 bool_or->add_literals(
NegatedRef(encoding_lit));
598 for (
const int lit : support) {
599 bool_or->add_literals(lit);
600 context->AddImplication(lit, encoding_lit);
// Adds a constraint relating `literal` to the subset `reachable_values` of
// the values present in `encoding`. Nothing is added when both have the same
// size.
// NOTE(review): `encoding` is declared by value (no `&`), so the hash map is
// copied on every call; the sibling LinkLiteralsAndValues takes it by const
// reference. Consider `const absl::flat_hash_map<int64_t, int>&`.
// NOTE(review): how `literal` is attached to the created constraints, and
// the loop that declares `value`, are on lines not visible here.
608void AddImplyInReachableValues(
int literal,
609 std::vector<int64_t>& reachable_values,
610 const absl::flat_hash_map<int64_t, int> encoding,
613 if (reachable_values.size() == encoding.size())
return;
// Few reachable values (at most half of the encoding): build a bool_or over
// the encoding literals of the reachable values.
614 if (reachable_values.size() <= encoding.size() / 2) {
616 ConstraintProto*
ct =
context->working_model->add_constraints();
618 BoolArgumentProto* bool_or =
ct->mutable_bool_or();
619 for (
const int64_t v : reachable_values) {
620 bool_or->add_literals(encoding.at(v));
// Otherwise: build a bool_and that handles the values NOT in the reachable
// set (the literal added for each of them is not visible here).
624 absl::flat_hash_set<int64_t> set(reachable_values.begin(),
625 reachable_values.end());
626 ConstraintProto*
ct =
context->working_model->add_constraints();
628 BoolArgumentProto* bool_and =
ct->mutable_bool_and();
630 if (!set.contains(
value)) {
// Expands an automaton constraint into Boolean constraints, one time step at
// a time, using encodings of the incoming state, the transition label and
// the outgoing state.
// NOTE(review): this listing has gaps. In particular the loops that declare
// `time` (forward pass, backward pass and main pass) are on lines not
// visible here; the comments describe only the visible statements.
637void ExpandAutomaton(ConstraintProto*
ct, PresolveContext*
context) {
638 AutomatonConstraintProto&
proto = *
ct->mutable_automaton();
// No variables: the visible code compares the starting state against each
// final state, records a "trivially feasible" statistic on a match, and
// otherwise reports the model as unsatisfiable.
640 if (
proto.vars_size() == 0) {
641 const int64_t initial_state =
proto.starting_state();
642 for (
const int64_t final_state :
proto.final_states()) {
643 if (initial_state == final_state) {
644 context->UpdateRuleStats(
"automaton: empty and trivially feasible");
649 return (
void)
context->NotifyThatModelIsUnsat(
650 "automaton: empty with an initial state not in the final states.");
651 }
// Variables but no transition at all: unsatisfiable.
else if (
proto.transition_label_size() == 0) {
652 return (
void)
context->NotifyThatModelIsUnsat(
653 "automaton: non-empty with no transition.");
656 const int n =
proto.vars_size();
657 const std::vector<int> vars = {
proto.vars().begin(),
proto.vars().end()};
// reachable_states[t] holds the states considered reachable at step t; it
// is seeded with the starting state at step 0.
660 const absl::flat_hash_set<int64_t> final_states(
661 {
proto.final_states().begin(),
proto.final_states().end()});
662 std::vector<absl::flat_hash_set<int64_t>> reachable_states(n + 1);
663 reachable_states[0].insert(
proto.starting_state());
// Forward pass: a transition contributes its head to step time + 1 when its
// tail is reachable at `time`, its label is in the domain of vars[time],
// and, on the last step, its head is a final state.
667 for (
int t = 0; t <
proto.transition_tail_size(); ++t) {
668 const int64_t
tail =
proto.transition_tail(t);
669 const int64_t label =
proto.transition_label(t);
670 const int64_t
head =
proto.transition_head(t);
671 if (!reachable_states[
time].contains(
tail))
continue;
672 if (!
context->DomainContains(vars[
time], label))
continue;
673 if (
time == n - 1 && !final_states.contains(
head))
continue;
674 reachable_states[
time + 1].insert(
head);
// Backward pass: keep at `time` only the tails of transitions whose head is
// still reachable at time + 1, then replace the set for `time`.
680 absl::flat_hash_set<int64_t> new_set;
681 for (
int t = 0; t <
proto.transition_tail_size(); ++t) {
682 const int64_t
tail =
proto.transition_tail(t);
683 const int64_t label =
proto.transition_label(t);
684 const int64_t
head =
proto.transition_head(t);
686 if (!reachable_states[
time].contains(
tail))
continue;
687 if (!
context->DomainContains(vars[
time], label))
continue;
688 if (!reachable_states[
time + 1].contains(
head))
continue;
689 new_set.insert(
tail);
691 reachable_states[
time].swap(new_set);
// Encodings, each mapping a value to a literal: `encoding` for the label of
// the current variable, `in_encoding` / `out_encoding` for the state before
// and after the current step.
699 absl::flat_hash_map<int64_t, int> encoding;
700 absl::flat_hash_map<int64_t, int> in_encoding;
701 absl::flat_hash_map<int64_t, int> out_encoding;
702 bool removed_values =
false;
// Collect the transitions usable at `time` as parallel vectors
// (in_states[i], labels[i], out_states[i]).
708 std::vector<int64_t> in_states;
709 std::vector<int64_t> labels;
710 std::vector<int64_t> out_states;
711 for (
int i = 0; i <
proto.transition_label_size(); ++i) {
712 const int64_t
tail =
proto.transition_tail(i);
713 const int64_t label =
proto.transition_label(i);
714 const int64_t
head =
proto.transition_head(i);
716 if (!reachable_states[
time].contains(
tail))
continue;
717 if (!reachable_states[
time + 1].contains(
head))
continue;
718 if (!
context->DomainContains(vars[
time], label))
continue;
723 in_states.push_back(
tail);
724 labels.push_back(label);
// On the last step every outgoing state is recorded as 0.
728 out_states.push_back(
time + 1 == n ? 0 :
head);
// A single usable transition fixes the variable to that transition's label.
732 const int num_tuples = in_states.size();
733 if (num_tuples == 1) {
734 if (!
context->IntersectDomainWith(vars[
time], Domain(labels.front()))) {
735 VLOG(1) <<
"Infeasible automaton.";
// Restrict the variable domain using the collected labels (the arguments of
// the intersection are on lines not visible here).
744 std::vector<int64_t> transitions = labels;
748 if (!
context->IntersectDomainWith(
750 VLOG(1) <<
"Infeasible automaton.";
// One value-encoding literal per remaining value of vars[time].
757 for (
const int64_t v :
context->DomainOf(vars[
time]).Values()) {
758 encoding[v] =
context->GetOrCreateVarValueEncoding(vars[
time], v);
// Number of tuples in which each in-state, label and out-state appears.
765 absl::flat_hash_map<int64_t, int> in_count;
766 absl::flat_hash_map<int64_t, int> transition_count;
767 absl::flat_hash_map<int64_t, int> out_count;
768 for (
int i = 0; i < num_tuples; ++i) {
769 in_count[in_states[i]]++;
770 transition_count[labels[i]]++;
771 out_count[out_states[i]]++;
// Build out_encoding for the states in `states` (how `states` is
// deduplicated is on lines not visible here).
778 std::vector<int64_t> states = out_states;
781 out_encoding.clear();
782 if (states.size() == 2) {
784 out_encoding[states[0]] =
var;
786 }
else if (states.size() > 2) {
// Small helper: is_unique becomes false once Set() receives a value that
// differs from the stored `value`.
787 struct UniqueDetector {
788 void Set(int64_t v) {
789 if (!is_unique)
return;
791 if (v !=
value) is_unique =
false;
798 bool is_unique =
true;
804 absl::flat_hash_map<int64_t, UniqueDetector> out_to_in;
805 absl::flat_hash_map<int64_t, UniqueDetector> out_to_transition;
806 for (
int i = 0; i < num_tuples; ++i) {
807 out_to_in[out_states[i]].Set(in_states[i]);
808 out_to_transition[out_states[i]].Set(labels[i]);
// For each out-state, reuse an existing literal when the state is reached
// from a single in-state (or a single label) with matching tuple counts;
// a new Boolean variable is created at 832 (presumably the fallback).
811 for (
const int64_t state : states) {
814 if (!in_encoding.empty() && out_to_in[state].is_unique) {
815 const int64_t unique_in = out_to_in[state].value;
816 if (in_count[unique_in] == out_count[state]) {
817 out_encoding[state] = in_encoding[unique_in];
824 if (!encoding.empty() && out_to_transition[state].is_unique) {
825 const int64_t unique_transition = out_to_transition[state].value;
826 if (transition_count[unique_transition] == out_count[state]) {
827 out_encoding[state] = encoding[unique_transition];
832 out_encoding[state] =
context->NewBoolVar();
// "Light" encoding: selected when there are more tuples than encoding
// literals and all three encodings are non-empty.
850 const int num_involved_variables =
851 in_encoding.size() + encoding.size() + out_encoding.size();
852 const bool use_light_encoding = (num_tuples > num_involved_variables);
853 if (use_light_encoding && !in_encoding.empty() && !encoding.empty() &&
854 !out_encoding.empty()) {
// Each in-state literal restricts the labels and out-states it can reach.
858 absl::flat_hash_map<int64_t, std::vector<int64_t>> in_to_label;
859 absl::flat_hash_map<int64_t, std::vector<int64_t>> in_to_out;
860 for (
int i = 0; i < num_tuples; ++i) {
861 in_to_label[in_states[i]].push_back(labels[i]);
862 in_to_out[in_states[i]].push_back(out_states[i]);
864 for (
const auto [in_value, in_literal] : in_encoding) {
865 AddImplyInReachableValues(in_literal, in_to_label[in_value], encoding,
867 AddImplyInReachableValues(in_literal, in_to_out[in_value], out_encoding,
// One clause per tuple: not(in) or not(label) or out.
872 for (
int i = 0; i < num_tuples; ++i) {
874 context->working_model->add_constraints()->mutable_bool_or();
875 bool_or->add_literals(
NegatedRef(in_encoding.at(in_states[i])));
876 bool_or->add_literals(
NegatedRef(encoding.at(labels[i])));
877 bool_or->add_literals(out_encoding.at(out_states[i]));
// The out encoding becomes the in encoding of the next step.
880 in_encoding.swap(out_encoding);
881 out_encoding.clear();
// Tuple-literal encoding: one literal per tuple.
889 std::vector<int> tuple_literals;
// Two tuples share one Boolean variable and its negation.
890 if (num_tuples == 2) {
891 const int bool_var =
context->NewBoolVar();
892 tuple_literals.push_back(bool_var);
893 tuple_literals.push_back(
NegatedRef(bool_var));
// Otherwise an exactly_one over the tuple literals; a tuple reuses an
// existing encoding literal when its in-state, label or out-state occurs in
// that tuple only, and gets a new Boolean variable at 909.
898 BoolArgumentProto* exactly_one =
899 context->working_model->add_constraints()->mutable_exactly_one();
900 for (
int i = 0; i < num_tuples; ++i) {
902 if (in_count[in_states[i]] == 1 && !in_encoding.empty()) {
903 tuple_literal = in_encoding[in_states[i]];
904 }
else if (transition_count[labels[i]] == 1 && !encoding.empty()) {
905 tuple_literal = encoding[labels[i]];
906 }
else if (out_count[out_states[i]] == 1 && !out_encoding.empty()) {
907 tuple_literal = out_encoding[out_states[i]];
909 tuple_literal =
context->NewBoolVar();
912 tuple_literals.push_back(tuple_literal);
913 exactly_one->add_literals(tuple_literal);
// Link the tuple literals to each non-empty encoding.
917 if (!in_encoding.empty()) {
918 LinkLiteralsAndValues(tuple_literals, in_states, in_encoding,
context);
920 if (!encoding.empty()) {
921 LinkLiteralsAndValues(tuple_literals, labels, encoding,
context);
923 if (!out_encoding.empty()) {
924 LinkLiteralsAndValues(tuple_literals, out_states, out_encoding,
context);
// The out encoding becomes the in encoding of the next step.
927 in_encoding.swap(out_encoding);
928 out_encoding.clear();
931 if (removed_values) {
932 context->UpdateRuleStats(
"automaton: reduced variable domains");
934 context->UpdateRuleStats(
"automaton: expanded");
// Expands a negated table constraint: builds a clause for each listed tuple
// from value-encoding literals.
// NOTE(review): the declarations of `count` and `any_value`, the tuple
// compression, and how `clause` is filled are on lines not visible here.
938void ExpandNegativeTable(ConstraintProto*
ct, PresolveContext*
context) {
939 TableConstraintProto& table = *
ct->mutable_table();
940 const int num_vars = table.vars_size();
// NOTE(review): divides by num_vars; no guard against num_vars == 0 is
// visible in this listing - confirm that callers or validation exclude it.
941 const int num_original_tuples = table.values_size() / num_vars;
// Unflatten table.values() into one vector of num_vars values per tuple.
942 std::vector<std::vector<int64_t>> tuples(num_original_tuples);
944 for (
int i = 0; i < num_original_tuples; ++i) {
945 for (
int j = 0; j < num_vars; ++j) {
946 tuples[i].push_back(table.values(count++));
950 if (tuples.empty()) {
951 context->UpdateRuleStats(
"table: empty negated constraint");
// Current domain size of each variable.
958 std::vector<int64_t> domain_sizes;
959 for (
int i = 0; i < num_vars; ++i) {
960 domain_sizes.push_back(
context->DomainOf(table.vars(i)).Size());
// For each tuple, positions equal to any_value are skipped; the other
// positions use the (var == value) encoding literal, and the collected
// literals form a bool_or.
965 std::vector<int> clause;
966 for (
const std::vector<int64_t>& tuple : tuples) {
968 for (
int i = 0; i < num_vars; ++i) {
969 const int64_t
value = tuple[i];
970 if (
value == any_value)
continue;
973 context->GetOrCreateVarValueEncoding(table.vars(i),
value);
978 BoolArgumentProto* bool_or =
979 context->working_model->add_constraints()->mutable_bool_or();
980 for (
const int lit : clause) {
981 bool_or->add_literals(lit);
984 context->UpdateRuleStats(
"table: expanded negated constraint");
// Links the tuple literals of a table to the value encoding of `variable`.
// values[i] is the value taken by `variable` in tuple i, or `any_value` when
// the tuple does not constrain it. tuple_literals and values must have the
// same size (CHECKed).
// NOTE(review): this listing has gaps; only visible statements are described.
994void ProcessOneVariable(
const std::vector<int>& tuple_literals,
995 const std::vector<int64_t>& values,
int variable,
996 int64_t any_value, PresolveContext*
context) {
997 VLOG(2) <<
"Process var(" << variable <<
") with domain "
998 <<
context->DomainOf(variable) <<
" and " << values.size()
1000 CHECK_EQ(tuple_literals.size(), values.size());
// Split the tuples: those holding any_value go to tuples_with_any, the
// others (presumably via an else branch) become (value, literal) pairs.
1003 std::vector<int> tuples_with_any;
1004 std::vector<std::pair<int64_t, int>> pairs;
1005 for (
int i = 0; i < values.size(); ++i) {
1006 const int64_t
value = values[i];
1007 if (
value == any_value) {
1008 tuples_with_any.push_back(tuple_literals[i]);
1012 pairs.emplace_back(
value, tuple_literals[i]);
// Sort by value so that equal values are adjacent, then handle one run of
// equal values per iteration (`i` is advanced by the inner loop).
1017 std::vector<int> selected;
1018 std::sort(pairs.begin(), pairs.end());
1019 for (
int i = 0; i < pairs.size();) {
1021 const int64_t
value = pairs[i].first;
1022 for (; i < pairs.size() && pairs[i].first ==
value; ++i) {
1023 selected.push_back(pairs[i].second);
1026 CHECK(!selected.empty() || !tuples_with_any.empty());
// A value supported by exactly one tuple, with no any_value tuple: that
// tuple literal is registered as the encoding of (variable == value).
1027 if (selected.size() == 1 && tuples_with_any.empty()) {
1028 context->InsertVarValueEncoding(selected.front(), variable,
value);
// Otherwise (presumably the else branch): each selected literal implies
// value_literal, and the clause not(value_literal) or selected... or
// tuples_with_any... is added. The initializer of value_literal is on a
// line not visible here.
1030 const int value_literal =
1032 BoolArgumentProto* no_support =
1033 context->working_model->add_constraints()->mutable_bool_or();
1034 for (
const int lit : selected) {
1035 no_support->add_literals(lit);
1036 context->AddImplication(lit, value_literal);
1038 for (
const int lit : tuples_with_any) {
1039 no_support->add_literals(lit);
1043 no_support->add_literals(
NegatedRef(value_literal));
// Encodes a table over exactly two variables (vars[0], vars[1]) with
// implications and clauses between the value-encoding literals of both sides.
// NOTE(review): the last parameter, the body of the early-exit at 1057, and
// several statements of the lambda are on lines not visible here.
1049void AddSizeTwoTable(
1050 const std::vector<int>& vars,
1051 const std::vector<std::vector<int64_t>>& tuples,
1052 const std::vector<absl::flat_hash_set<int64_t>>& values_per_var,
1055 const int left_var = vars[0];
1056 const int right_var = vars[1];
// Special handling when either variable is fixed (body not visible here).
1057 if (
context->DomainOf(left_var).IsFixed() ||
1058 context->DomainOf(right_var).IsFixed()) {
// For each literal on one side, the list of literals it is paired with on
// the other side. btree_map keeps the keys in sorted order.
1064 absl::btree_map<int, std::vector<int>> left_to_right;
1065 absl::btree_map<int, std::vector<int>> right_to_left;
1067 for (
const auto& tuple : tuples) {
1068 const int64_t left_value(tuple[0]);
1069 const int64_t right_value(tuple[1]);
1071 CHECK(
context->DomainContains(right_var, right_value));
1073 const int left_literal =
1074 context->GetOrCreateVarValueEncoding(left_var, left_value);
1075 const int right_literal =
1076 context->GetOrCreateVarValueEncoding(right_var, right_value);
1077 left_to_right[left_literal].push_back(right_literal);
1078 right_to_left[right_literal].push_back(left_literal);
// Counters reported in the VLOG at the end.
1081 int num_implications = 0;
1082 int num_clause_added = 0;
1083 int num_large_clause_added = 0;
// Adds the support constraint of `lit`: nothing when every value of the
// other variable is a support, an implication for a single support, and a
// bool_or over the supports in the remaining case.
1084 auto add_support_constraint =
1085 [
context, &num_clause_added, &num_large_clause_added, &num_implications](
1086 int lit,
const std::vector<int>& support_literals,
1087 int max_support_size) {
1088 if (support_literals.size() == max_support_size)
return;
1089 if (support_literals.size() == 1) {
1090 context->AddImplication(lit, support_literals.front());
1093 BoolArgumentProto* bool_or =
1094 context->working_model->add_constraints()->mutable_bool_or();
1095 for (
const int support_literal : support_literals) {
1096 bool_or->add_literals(support_literal);
// A clause counts as "large" above half of the maximum support size.
1100 if (support_literals.size() > max_support_size / 2) {
1101 num_large_clause_added++;
// Apply in both directions; the maximum support size is the number of
// values recorded for the opposite variable.
1106 for (
const auto& it : left_to_right) {
1107 add_support_constraint(it.first, it.second, values_per_var[1].size());
1109 for (
const auto& it : right_to_left) {
1110 add_support_constraint(it.first, it.second, values_per_var[0].size());
1112 VLOG(2) <<
"Table: 2 variables, " << tuples.size() <<
" tuples encoded using "
1113 << num_clause_added <<
" clauses, including "
1114 << num_large_clause_added <<
" large clauses, " << num_implications
// Expands a positive table constraint: filters out invalid tuples, reduces
// variable domains, handles small special cases, then creates one literal
// per remaining tuple and links it to each variable.
// NOTE(review): this listing has gaps (declarations of `count`, `new_size`,
// `any_value`, `message`, the compression call and most early returns are
// not visible); only visible statements are described.
1118void ExpandPositiveTable(ConstraintProto*
ct, PresolveContext*
context) {
1119 const TableConstraintProto& table =
ct->table();
1120 const int num_vars = table.vars_size();
// NOTE(review): divides by num_vars; no guard against num_vars == 0 is
// visible in this listing - confirm that callers or validation exclude it.
1121 const int num_original_tuples = table.values_size() / num_vars;
// Unflatten table.values() into one vector of num_vars values per tuple.
1124 const std::vector<int> vars(table.vars().begin(), table.vars().end());
1125 std::vector<std::vector<int64_t>> tuples(num_original_tuples);
1127 for (
int tuple_index = 0; tuple_index < num_original_tuples; ++tuple_index) {
1128 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1129 tuples[tuple_index].push_back(table.values(count++));
// Filter pass: a tuple is tested value by value against the variable
// domains; kept tuples record their values in values_per_var and are
// swapped to position new_size, and the vector is truncated afterwards.
1135 std::vector<absl::flat_hash_set<int64_t>> values_per_var(num_vars);
1137 for (
int tuple_index = 0; tuple_index < num_original_tuples; ++tuple_index) {
1139 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1140 const int64_t
value = tuples[tuple_index][var_index];
1141 if (!
context->DomainContains(vars[var_index],
value)) {
1147 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1148 values_per_var[var_index].insert(tuples[tuple_index][var_index]);
1150 std::swap(tuples[tuple_index], tuples[new_size]);
1154 tuples.resize(new_size);
1155 const int num_valid_tuples = tuples.size();
// No valid tuple left: the model is unsatisfiable.
1157 if (tuples.empty()) {
1158 context->UpdateRuleStats(
"table: empty");
1159 return (
void)
context->NotifyThatModelIsUnsat();
// Count the variables that are fixed; the domain update that precedes the
// test is on lines not visible here.
1165 int num_fixed_variables = 0;
1166 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1170 values_per_var[var_index].end()})));
1171 if (
context->DomainOf(vars[var_index]).IsFixed()) {
1172 num_fixed_variables++;
1176 if (num_fixed_variables == num_vars - 1) {
1177 context->UpdateRuleStats(
"table: one variable not fixed");
1180 }
else if (num_fixed_variables == num_vars) {
1181 context->UpdateRuleStats(
"table: all variables fixed");
// Two-variable tables use the dedicated encoding.
1187 if (num_vars == 2) {
1188 AddSizeTwoTable(vars, tuples, values_per_var,
context);
1190 "table: expanded positive constraint with two variables");
// Count the distinct prefixes, a prefix being the first num_vars - 1 values
// of a tuple. The spans point into `tuples`.
1197 int num_prefix_tuples = 0;
1199 absl::flat_hash_set<absl::Span<const int64_t>> prefixes;
1200 for (
const std::vector<int64_t>& tuple : tuples) {
1201 prefixes.insert(absl::MakeSpan(tuple.data(), num_vars - 1));
1203 num_prefix_tuples = prefixes.size();
// Tuple counts before and after a compression step; the call at original
// line 1215 is not visible here.
1210 std::vector<int64_t> domain_sizes;
1211 for (
int i = 0; i < num_vars; ++i) {
1212 domain_sizes.push_back(values_per_var[i].size());
1214 const int num_tuples_before_compression = tuples.size();
1216 const int num_compressed_tuples = tuples.size();
1217 if (num_compressed_tuples < num_tuples_before_compression) {
1218 context->UpdateRuleStats(
"table: compress tuples");
1221 if (num_compressed_tuples == 1) {
1223 context->UpdateRuleStats(
"table: one tuple");
// True when no two valid tuples share the same prefix.
1229 const bool prefixes_are_all_different = num_prefix_tuples == num_valid_tuples;
1230 if (prefixes_are_all_different) {
1232 "TODO table: last value implied by previous values");
// Upper bound on the number of prefixes: product (through CapProd) of the
// value counts of all variables but the last.
1242 int64_t max_num_prefix_tuples = 1;
1243 for (
int var_index = 0; var_index + 1 < num_vars; ++var_index) {
1244 max_num_prefix_tuples =
1245 CapProd(max_num_prefix_tuples, values_per_var[var_index].size());
// Build a human-readable summary in `message`.
1249 absl::StrCat(
"Table: ", num_vars,
1250 " variables, original tuples = ", num_original_tuples);
1251 if (num_valid_tuples != num_original_tuples) {
1252 absl::StrAppend(&
message,
", valid tuples = ", num_valid_tuples);
1254 if (prefixes_are_all_different) {
1255 if (num_prefix_tuples < max_num_prefix_tuples) {
1256 absl::StrAppend(&
message,
", partial prefix = ", num_prefix_tuples,
"/",
1257 max_num_prefix_tuples);
1259 absl::StrAppend(&
message,
", full prefix = true");
1262 absl::StrAppend(&
message,
", num prefix tuples = ", num_prefix_tuples);
1264 if (num_compressed_tuples != num_valid_tuples) {
1266 ", compressed tuples = ", num_compressed_tuples);
1272 if (num_compressed_tuples == 2) {
1273 context->UpdateRuleStats(
"TODO table: two tuples");
// General case: one new Boolean variable per tuple, exactly one of them
// being true.
1278 std::vector<int> tuple_literals(num_compressed_tuples);
1279 BoolArgumentProto* exactly_one =
1280 context->working_model->add_constraints()->mutable_exactly_one();
1281 for (
int i = 0; i < num_compressed_tuples; ++i) {
1282 tuple_literals[i] =
context->NewBoolVar();
1283 exactly_one->add_literals(tuple_literals[i]);
// Link each variable to the tuple literals, column by column; variables
// with a single recorded value are skipped.
1286 std::vector<int64_t> values(num_compressed_tuples);
1287 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1288 if (values_per_var[var_index].size() == 1)
continue;
1289 for (
int i = 0; i < num_compressed_tuples; ++i) {
1290 values[i] = tuples[i][var_index];
1292 ProcessOneVariable(tuple_literals, values, vars[var_index], any_value,
1296 context->UpdateRuleStats(
"table: expanded positive constraint");
// Heuristic predicate used by ExpandAllDiff() to decide whether an
// all_different constraint is worth expanding.
//
// Inputs: the union of the domains of all expressions, the constraint itself
// and the presolve context.
//
// NOTE(review): this listing omits several original lines (the condition that
// guards the counter increment and every return statement), so the result of
// each test below is not visible here. Presumably each satisfied test returns
// true; confirm against the complete file.
1300bool AllDiffShouldBeExpanded(
const Domain& union_of_domains,
1301 ConstraintProto*
ct, PresolveContext*
context) {
// The proto is only read through a const reference, although it is obtained
// through the mutable accessor.
1302 const AllDifferentConstraintProto&
proto = *
ct->mutable_all_diff();
1303 const int num_exprs =
proto.exprs_size();
// Counts expressions that qualify as fully encoded; the qualifying test is
// on a line that is not part of this listing.
1304 int num_fully_encoded = 0;
1305 for (
int i = 0; i < num_exprs; ++i) {
1307 num_fully_encoded++;
// Small-domain test: the union holds at most twice as many values as there
// are expressions, or at most 32 values overall.
1311 if ((union_of_domains.Size() <= 2 *
proto.exprs_size()) ||
1312 (union_of_domains.Size() <= 32)) {
// Encoded test: every expression was counted above and the union holds fewer
// than 256 values.
1317 if (num_fully_encoded == num_exprs && union_of_domains.Size() < 256) {
// Expands an all_different constraint into one Boolean constraint per value of
// the union of the expression domains, built from value-encoding literals.
//
// force_alldiff_expansion: consulted together with AllDiffShouldBeExpanded()
// in the skip test below.
// ct: the constraint holding the all_diff proto.
//
// NOTE(review): this listing omits part of the signature, several conditions
// and all closing braces; comments below describe only the visible statements.
1324void ExpandAllDiff(
bool force_alldiff_expansion, ConstraintProto*
ct,
1326 AllDifferentConstraintProto&
proto = *
ct->mutable_all_diff();
// With zero or one expression there is nothing to process.
1327 if (
proto.exprs_size() <= 1)
return;
1329 const int num_exprs =
proto.exprs_size();
// Build the union of the domain supersets of all expressions.
1330 Domain union_of_domains =
context->DomainSuperSetOf(
proto.exprs(0));
1331 for (
int i = 1; i < num_exprs; ++i) {
1333 union_of_domains.UnionWith(
context->DomainSuperSetOf(
proto.exprs(i)));
// Skip test: the heuristic declined and expansion is not forced. The action
// taken is not visible in this listing (presumably an early return).
1336 if (!AllDiffShouldBeExpanded(union_of_domains,
ct,
context) &&
1337 !force_alldiff_expansion) {
// True when there are exactly as many expressions as values in the union.
1341 const bool is_a_permutation = num_exprs == union_of_domains.Size();
// Process each value of the union independently.
1346 for (
const int64_t v : union_of_domains.Values()) {
// Collect the expressions whose domain contains v.
1348 std::vector<LinearExpressionProto> possible_exprs;
1349 int fixed_expression_count = 0;
1350 for (
const LinearExpressionProto& expr :
proto.exprs()) {
1351 if (!
context->DomainContains(expr, v))
continue;
1352 possible_exprs.push_back(expr);
// The guard of this increment is not visible; judging by the counter name it
// counts expressions already fixed (to v) -- confirm.
1354 fixed_expression_count++;
// More than one counted expression for the same value: the model is reported
// as unsatisfiable.
1358 if (fixed_expression_count > 1) {
1360 return (
void)
context->NotifyThatModelIsUnsat();
1361 }
else if (fixed_expression_count == 1) {
// Exactly one counted expression: remove v from the domain of every
// non-fixed candidate expression.
1363 for (
const LinearExpressionProto& expr : possible_exprs) {
1364 if (
context->IsFixed(expr))
continue;
1365 if (!
context->IntersectDomainWith(expr, Domain(v).Complement())) {
1366 VLOG(1) <<
"Empty domain for a variable in ExpandAllDiff()";
// Adds either an exactly_one or an at_most_one constraint for this value.
// The selecting condition is on an omitted line (presumably
// is_a_permutation -- confirm).
1372 BoolArgumentProto* at_most_or_equal_one =
1374 ?
context->working_model->add_constraints()->mutable_exactly_one()
1375 :
context->working_model->add_constraints()->mutable_at_most_one();
1376 for (
const LinearExpressionProto& expr : possible_exprs) {
// Domains may have been reduced above, so membership of v is re-checked.
1379 if (!
context->DomainContains(expr, v))
continue;
// Literal of the affine value encoding for (expr, v), created if needed.
1384 const int encoding =
context->GetOrCreateAffineValueEncoding(expr, v);
1385 at_most_or_equal_one->add_literals(encoding);
// Record which flavor of expansion was applied.
1388 if (is_a_permutation) {
1389 context->UpdateRuleStats(
"all_diff: permutation expanded");
1391 context->UpdateRuleStats(
"all_diff: expanded");
// Handles linear constraints over exactly two variables for which exactly one
// reachable right-hand-side value is infeasible; the rule-stat names below
// describe this as the form ax + by != cte. When the set of violating
// assignments is small, the visible code creates value-encoding literals and
// bool_or constraints for them.
//
// NOTE(review): this listing omits many original lines, including the
// declarations of a, b, x0 and y0 (presumably produced by
// SolveDiophantineEquationOfSizeTwo -- confirm) and the bodies of several
// branches. Comments below describe only the visible statements.
1402void ExpandSomeLinearOfSizeTwo(ConstraintProto*
ct, PresolveContext*
context) {
1403 const LinearConstraintProto& arg =
ct->linear();
// Only constraints with exactly two variables are considered.
1404 if (arg.vars_size() != 2)
return;
1406 const int var1 = arg.vars(0);
1407 const int var2 = arg.vars(1);
1410 const int64_t coeff1 = arg.coeffs(0);
1411 const int64_t coeff2 = arg.coeffs(1);
// Superset of the values coeff1 * var1 + coeff2 * var2 can take.
1413 const Domain reachable_rhs_superset =
1414 context->DomainOf(var1).MultiplicationBy(coeff1).AdditionWith(
1415 context->DomainOf(var2).MultiplicationBy(coeff2));
// Reachable values that violate the constraint; the argument of the
// intersection is on an omitted line.
1417 const Domain infeasible_reachable_values =
1418 reachable_rhs_superset.IntersectionWith(
// Proceed only when exactly one reachable value is infeasible.
1422 if (infeasible_reachable_values.Size() != 1)
return;
// The single forbidden value.
1427 int64_t cte = infeasible_reachable_values.FixedValue();
1432 context->UpdateRuleStats(
"linear: expand always feasible ax + by != cte");
// Domain of the parameter z used below, where var1 = x0 + b * z and
// var2 = y0 - a * z. The start of this expression is on an omitted line.
1436 const Domain reduced_domain =
1438 .AdditionWith(Domain(-x0))
1439 .InverseMultiplicationBy(
b)
1440 .IntersectionWith(
context->DomainOf(var2)
1441 .AdditionWith(Domain(-y0))
1442 .InverseMultiplicationBy(-
a));
// Give up when more than 16 parameter values would have to be handled.
1444 if (reduced_domain.Size() > 16)
return;
1449 const int64_t size1 =
context->DomainOf(var1).Size();
1450 const int64_t size2 =
context->DomainOf(var2).Size();
// First pass over the parameter values: sanity checks plus a look at which
// value encodings already exist.
1451 for (
const int64_t z : reduced_domain.Values()) {
1452 const int64_t value1 = x0 +
b * z;
1453 const int64_t value2 = y0 -
a * z;
1454 DCHECK(
context->DomainContains(var1, value1)) <<
"value1 = " << value1;
1455 DCHECK(
context->DomainContains(var2, value2)) <<
"value2 = " << value2;
1456 DCHECK_EQ(coeff1 * value1 + coeff2 * value2,
1457 infeasible_reachable_values.FixedValue());
// True when var1 has no encoding for value1 yet, or its domain has size 2.
// The action taken is not visible in this listing.
1459 if (!
context->HasVarValueEncoding(var1, value1,
nullptr) || size1 == 2) {
// Same test for var2 and value2.
1462 if (!
context->HasVarValueEncoding(var2, value2,
nullptr) || size2 == 2) {
// Second pass: fetch or create the encoding literals of each violating pair
// and add a bool_or constraint. How the literals are inserted is on omitted
// lines.
1469 for (
const int64_t z : reduced_domain.Values()) {
1470 const int64_t value1 = x0 +
b * z;
1471 const int64_t value2 = y0 -
a * z;
1473 const int lit1 =
context->GetOrCreateVarValueEncoding(var1, value1);
1474 const int lit2 =
context->GetOrCreateVarValueEncoding(var2, value2);
1476 context->working_model->add_constraints()->mutable_bool_or();
// Iterates over the enforcement literals of the original constraint; the
// loop body is not visible.
1479 for (
const int lit :
ct->enforcement_literal()) {
1484 context->UpdateRuleStats(
"linear: expand small ax + by != cte");
// Body of the top-level expansion driver. The signature line is missing from
// this listing; presumably this is ExpandCpModel(PresolveContext* context) --
// confirm against the complete file.
//
// Visible structure: early exits, a first pass over all constraints, a second
// pass over all constraints, cleanup, and finally marking the model as
// expanded. The bodies of the switch cases are on omitted lines.
//
// Early exits: expansion disabled by parameter, model already unsat, or model
// already expanded.
1491 if (
context->params().disable_constraint_expansion())
return;
1492 if (
context->ModelIsUnsat())
return;
1496 if (
context->ModelIsExpanded())
return;
1499 context->InitializeNewDomains();
1502 context->ClearPrecedenceCache();
// First pass: dispatch on the constraint kind for reservoir, int_mod,
// int_prod, element, inverse, automaton and table constraints.
1505 for (
int i = 0; i <
context->working_model->constraints_size(); ++i) {
1506 ConstraintProto*
const ct =
context->working_model->mutable_constraints(i);
1508 switch (
ct->constraint_case()) {
1509 case ConstraintProto::ConstraintCase::kReservoir:
1512 case ConstraintProto::ConstraintCase::kIntMod:
1515 case ConstraintProto::ConstraintCase::kIntProd:
1518 case ConstraintProto::ConstraintCase::kElement:
1521 case ConstraintProto::ConstraintCase::kInverse:
1524 case ConstraintProto::ConstraintCase::kAutomaton:
1527 case ConstraintProto::ConstraintCase::kTable:
// Table constraints are split on their negated flag.
1528 if (
ct->table().negated()) {
// Refresh variable usage for constraints appended during this iteration, and
// for constraint i itself when its constraint case is now unset.
1541 context->UpdateNewConstraintsVariableUsage();
1542 if (
ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) {
1543 context->UpdateConstraintVariableUsage(i);
// Unsat check after processing the constraint; the action is not visible.
1547 if (
context->ModelIsUnsat()) {
// Second pass: all_diff and linear constraints.
1556 for (
int i = 0; i <
context->working_model->constraints_size(); ++i) {
1557 ConstraintProto*
const ct =
context->working_model->mutable_constraints(i);
1559 switch (
ct->constraint_case()) {
1560 case ConstraintProto::ConstraintCase::kAllDiff:
// The expand_alldiff_constraints parameter is passed as the force flag.
1561 ExpandAllDiff(
context->params().expand_alldiff_constraints(),
ct,
1564 case ConstraintProto::ConstraintCase::kLinear:
// Same usage bookkeeping as in the first pass.
1575 context->UpdateNewConstraintsVariableUsage();
1576 if (
ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) {
1577 context->UpdateConstraintVariableUsage(i);
1581 if (
context->ModelIsUnsat()) {
// Cleanup after both passes.
1591 context->ClearPrecedenceCache();
1594 context->InitializeNewDomains();
// Loop over every variable of the working model; the call that receives the
// mutable variable proto starts on an omitted line.
1597 for (
int i = 0; i <
context->working_model->variables_size(); ++i) {
1599 context->working_model->mutable_variables(i));
// Mark the model as expanded; the early exit on ModelIsExpanded() above tests
// this state.
1602 context->NotifyThatModelIsExpanded();
#define CHECK_EQ(val1, val2)
#define DCHECK(condition)
#define DCHECK_EQ(val1, val2)
#define VLOG(verboselevel)
Domain Complement() const
Returns the set Int64 ∖ D.
static Domain FromValues(std::vector< int64_t > values)
Creates a domain from the union of an unsorted list of integer values.
GurobiMPCallbackContext * context
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
void swap(IdMap< K, V > &a, IdMap< K, V > &b)
bool RefIsPositive(int ref)
void CompressTuples(absl::Span< const int64_t > domain_sizes, int64_t any_value, std::vector< std::vector< int64_t > > *tuples)
void ExpandCpModel(PresolveContext *context)
bool SolveDiophantineEquationOfSizeTwo(int64_t &a, int64_t &b, int64_t &cte, int64_t &x0, int64_t &y0)
void FillDomainInProto(const Domain &domain, ProtoWithDomain *proto)
Domain ReadDomainFromProto(const ProtoWithDomain &proto)
void AddLinearExpressionToLinearConstraint(const LinearExpressionProto &expr, int64_t coefficient, LinearConstraintProto *linear)
Collection of objects used to extend the Constraint Solver library.
int64_t CapAdd(int64_t x, int64_t y)
int64_t CapSub(int64_t x, int64_t y)
std::string ProtobufShortDebugString(const P &message)
int64_t CapProd(int64_t x, int64_t y)
#define SOLVER_LOG(logger,...)
#define VLOG_IS_ON(verboselevel)