22#include "absl/container/flat_hash_map.h"
// Expands a reservoir constraint into Boolean and linear constraints.
//
// Visible steps:
//  - If min_level > max_level, the level domain is empty and the model is
//    reported infeasible via NotifyThatModelIsUnsat().
//  - Each event i has an "active" literal: reservoir.active_literals(i), or
//    the constant-true literal when the reservoir lists no active literals.
//    Events whose active literal is known false are skipped.
//  - When level changes of both signs occur, a reified precedence literal is
//    created (and cached) for every ordered pair of events, named
//    "i before j". Then, for each event i, a linear constraint enforced by
//    i's active literal requires that the level changes of the events
//    preceding i, together with i's own change, stay within
//    [min_level, max_level] (up to an `offset` correction whose computation
//    is not visible).
//  - A linear constraint over all level changes, weighted by the active
//    literals, is restricted to [min_level, max_level].
//
// NOTE(review): this fragment is incomplete. Several statements, the
// declarations of `offset` and `sum`, and the closing braces are not
// visible, so the exact branch structure could not be checked. Confirm
// against the full source.
38void ExpandReservoir(ConstraintProto*
ct, PresolveContext*
context) {
39 if (
ct->reservoir().min_level() >
ct->reservoir().max_level()) {
40 VLOG(1) <<
"Empty level domain in reservoir constraint.";
41 return (
void)
context->NotifyThatModelIsUnsat();
44 const ReservoirConstraintProto& reservoir =
ct->reservoir();
45 const int num_events = reservoir.time_exprs_size();
47 const int true_literal =
context->GetOrCreateConstantVar(1);
// Active literal of event `index`. Falls back to the constant-true literal
// when the reservoir declares no active literals.
49 const auto is_active_literal = [&reservoir, true_literal](
int index) {
50 if (reservoir.active_literals_size() == 0)
return true_literal;
51 return reservoir.active_literals(
index);
// Presumably counts the positive and negative level changes. The loop body
// is not visible in this fragment.
54 int num_positives = 0;
55 int num_negatives = 0;
56 for (
const int64_t
demand : reservoir.level_changes()) {
// Maps an ordered event pair (i, j) to the literal named "i before j".
64 absl::flat_hash_map<std::pair<int, int>,
int> precedence_cache;
// The precedence-based encoding is only built when both signs occur.
66 if (num_positives > 0 && num_negatives > 0) {
68 for (
int i = 0; i < num_events - 1; ++i) {
69 const int active_i = is_active_literal(i);
70 if (
context->LiteralIsFalse(active_i))
continue;
71 const LinearExpressionProto& time_i = reservoir.time_exprs(i);
73 for (
int j = i + 1; j < num_events; ++j) {
74 const int active_j = is_active_literal(j);
75 if (
context->LiteralIsFalse(active_j))
continue;
76 const LinearExpressionProto& time_j = reservoir.time_exprs(j);
// Both directions are created, so every ordered pair of
// not-known-inactive events ends up in the cache.
78 const int i_lesseq_j =
context->GetOrCreateReifiedPrecedenceLiteral(
79 time_i, time_j, active_i, active_j);
80 context->working_model->mutable_variables(i_lesseq_j)
81 ->set_name(absl::StrCat(i,
" before ", j));
82 precedence_cache[{i, j}] = i_lesseq_j;
83 const int j_lesseq_i =
context->GetOrCreateReifiedPrecedenceLiteral(
84 time_j, time_i, active_j, active_i);
85 context->working_model->mutable_variables(j_lesseq_i)
86 ->set_name(absl::StrCat(j,
" before ", i));
87 precedence_cache[{j, i}] = j_lesseq_i;
// One level constraint per possibly-active event i, enforced by i's active
// literal.
95 for (
int i = 0; i < num_events; ++i) {
96 const int active_i = is_active_literal(i);
97 if (
context->LiteralIsFalse(active_i))
continue;
100 ConstraintProto*
const level =
context->working_model->add_constraints();
101 level->add_enforcement_literal(active_i);
105 for (
int j = 0; j < num_events; ++j) {
106 if (i == j)
continue;
107 const int active_j = is_active_literal(j);
108 if (
context->LiteralIsFalse(active_j))
continue;
// The cache was filled for every ordered pair of events that are not known
// to be inactive, so this lookup is expected to succeed.
110 const auto prec_it = precedence_cache.find({j, i});
111 CHECK(prec_it != precedence_cache.end());
112 const int prec_lit = prec_it->second;
113 const int64_t
demand = reservoir.level_changes(j);
// NOTE(review): the two add_vars/add_coeffs pairs below (with demand and
// -demand) are presumably the two arms of a branch that is not visible
// here. Confirm.
115 level->mutable_linear()->add_vars(prec_lit);
116 level->mutable_linear()->add_coeffs(
demand);
118 level->mutable_linear()->add_vars(prec_lit);
119 level->mutable_linear()->add_coeffs(-
demand);
// Saturating arithmetic (CapSub/CapAdd) keeps these bounds from
// overflowing int64.
125 const int64_t demand_i = reservoir.level_changes(i);
126 level->mutable_linear()->add_domain(
127 CapAdd(
CapSub(reservoir.min_level(), demand_i), offset));
128 level->mutable_linear()->add_domain(
129 CapAdd(
CapSub(reservoir.max_level(), demand_i), offset));
// Aggregate constraint: the sum of the level changes of the active events
// lies in [min_level, max_level].
135 context->working_model->add_constraints()->mutable_linear();
136 for (
int i = 0; i < num_events; ++i) {
137 sum->add_vars(is_active_literal(i));
138 sum->add_coeffs(reservoir.level_changes(i));
140 sum->add_domain(reservoir.min_level());
141 sum->add_domain(reservoir.max_level());
145 context->UpdateRuleStats(
"reservoir: expanded");
// Expands int_mod (target = exprs(0) mod exprs(1)) into an int_div, an
// int_prod and a linear constraint over two new integer variables.
//
// Visible steps:
//  - Returns without expanding when the modulo expression is fixed.
//  - Intersects the target domain with a superset of the possible
//    remainders (PositiveModuloBySuperset). The failure branch is not fully
//    visible.
//  - Creates `div_var`, constrained by div_var = expr / mod_expr.
//  - Creates `prod_var`, constrained by prod_var = div_var * mod_expr. Its
//    domain is also intersected with the domain of (expr - target).
//  - Adds a linear constraint whose terms are not visible here, presumably
//    expr == prod_var + target. Confirm.
// Every new constraint copies `ct`'s enforcement literals.
//
// NOTE(review): this fragment is incomplete (the head of the prod_domain
// expression, the linear terms and the closing braces are not visible).
148void ExpandIntMod(ConstraintProto*
ct, PresolveContext*
context) {
149 const LinearArgumentProto& int_mod =
ct->int_mod();
150 const LinearExpressionProto& mod_expr = int_mod.exprs(1);
151 if (
context->IsFixed(mod_expr))
return;
153 const LinearExpressionProto& expr = int_mod.exprs(0);
154 const LinearExpressionProto& target_expr = int_mod.target();
// Tighten the target to the values a remainder can take.
157 if (!
context->IntersectDomainWith(
158 target_expr,
context->DomainSuperSetOf(expr).PositiveModuloBySuperset(
159 context->DomainSuperSetOf(mod_expr)))) {
// Creates a new constraint that carries the same enforcement literals as `ct`.
164 auto new_enforced_constraint = [&]() {
165 ConstraintProto* new_ct =
context->working_model->add_constraints();
166 *new_ct->mutable_enforcement_literal() =
ct->enforcement_literal();
// Quotient variable: div_var = expr / mod_expr.
171 const int div_var =
context->NewIntVar(
172 context->DomainSuperSetOf(expr).PositiveDivisionBySuperset(
173 context->DomainSuperSetOf(mod_expr)));
174 LinearExpressionProto div_expr;
175 div_expr.add_vars(div_var);
176 div_expr.add_coeffs(1);
178 LinearArgumentProto*
const div_proto =
179 new_enforced_constraint()->mutable_int_div();
180 *div_proto->mutable_target() = div_expr;
181 *div_proto->add_exprs() = expr;
182 *div_proto->add_exprs() = mod_expr;
// Product variable: prod_var = div_var * mod_expr.
185 const Domain prod_domain =
187 .ContinuousMultiplicationBy(
context->DomainSuperSetOf(mod_expr))
188 .IntersectionWith(
context->DomainSuperSetOf(expr).AdditionWith(
189 context->DomainSuperSetOf(target_expr).Negation()));
190 const int prod_var =
context->NewIntVar(prod_domain);
191 LinearExpressionProto prod_expr;
192 prod_expr.add_vars(prod_var);
193 prod_expr.add_coeffs(1);
195 LinearArgumentProto*
const int_prod =
196 new_enforced_constraint()->mutable_int_prod();
197 *int_prod->mutable_target() = prod_expr;
198 *int_prod->add_exprs() = div_expr;
199 *int_prod->add_exprs() = mod_expr;
// Linear link between expr, prod_var and target. Its terms are not visible.
202 LinearConstraintProto*
const lin =
203 new_enforced_constraint()->mutable_linear();
211 context->UpdateRuleStats(
"int_mod: expanded");
// Expands product_expr = bool_ref * int_expr into two enforced linear
// equalities, both with domain [0, 0]: one enforced by `bool_ref`, the other
// by its negation.
//
// The terms of each equality are added by calls that are not visible in this
// fragment. Presumably the constraint is product_expr - int_expr == 0 when
// the literal is true and product_expr == 0 when it is false. Confirm. The
// declaration of `context` is also not visible in the signature here.
215void ExpandIntProdWithBoolean(
int bool_ref,
216 const LinearExpressionProto& int_expr,
217 const LinearExpressionProto& product_expr,
// Case bool_ref is true.
219 ConstraintProto*
const one =
context->working_model->add_constraints();
220 one->add_enforcement_literal(bool_ref);
221 one->mutable_linear()->add_domain(0);
222 one->mutable_linear()->add_domain(0);
225 one->mutable_linear());
// Case bool_ref is false.
227 ConstraintProto*
const zero =
context->working_model->add_constraints();
228 zero->add_enforcement_literal(
NegatedRef(bool_ref));
229 zero->mutable_linear()->add_domain(0);
230 zero->mutable_linear()->add_domain(0);
232 zero->mutable_linear());
// Expands an int_prod constraint p = a * b when one factor is a Boolean
// literal.
//
// Only products of exactly two expressions are handled. The flags
// `a_is_literal` / `b_is_literal` and the expansion calls are computed in
// lines not visible in this fragment. Presumably they delegate to
// ExpandIntProdWithBoolean. Confirm.
235void ExpandIntProd(ConstraintProto*
ct, PresolveContext*
context) {
236 const LinearArgumentProto& int_prod =
ct->int_prod();
237 if (int_prod.exprs_size() != 2)
return;
238 const LinearExpressionProto&
a = int_prod.exprs(0);
239 const LinearExpressionProto&
b = int_prod.exprs(1);
240 const LinearExpressionProto& p = int_prod.target();
// Prefer the first factor as the Boolean one; fall back to the second.
247 if (a_is_literal && !b_is_literal) {
250 context->UpdateRuleStats(
"int_prod: expanded product with Boolean var");
251 }
else if (b_is_literal) {
254 context->UpdateRuleStats(
"int_prod: expanded product with Boolean var");
// Expands an inverse constraint between `f_direct` and `f_inverse` (both of
// size n) into shared value-encoding literals.
//
// Visible steps:
//  - Restricts every variable of both arrays to [0, n - 1].
//  - When `used_variables` does not hold exactly 2 * n entries, extra
//    pairwise domain reductions are applied for i != j. How the set is
//    filled is not visible; presumably this detects variables used more than
//    once.
//  - filter_inverse_domain(direct, inverse) keeps value j in direct[i]'s
//    domain only if inverse[j] can take value i. It is applied in both
//    directions.
//  - For each i and each value j of f_direct[i], the literal
//    "f_direct[i] == j" is made to also encode "f_inverse[j] == i", reusing
//    an existing encoding of the latter when one exists.
//
// NOTE(review): this fragment is incomplete (several statements, including
// the `r_j_i` declaration and the closing braces, are not visible).
258void ExpandInverse(ConstraintProto*
ct, PresolveContext*
context) {
259 const auto& f_direct =
ct->inverse().f_direct();
260 const auto& f_inverse =
ct->inverse().f_inverse();
261 const int n = f_direct.size();
// Every variable in both arrays is a position in [0, n - 1].
270 absl::flat_hash_set<int> used_variables;
271 for (
const int ref : f_direct) {
273 if (!
context->IntersectDomainWith(ref, Domain(0, n - 1))) {
274 VLOG(1) <<
"Empty domain for a variable in ExpandInverse()";
278 for (
const int ref : f_inverse) {
280 if (!
context->IntersectDomainWith(ref, Domain(0, n - 1))) {
281 VLOG(1) <<
"Empty domain for a variable in ExpandInverse()";
288 if (used_variables.size() != 2 * n) {
289 for (
int i = 0; i < n; ++i) {
290 for (
int j = 0; j < n; ++j) {
295 if (i == j)
continue;
296 if (!
context->IntersectDomainWith(
// Scratch buffer reused across calls of filter_inverse_domain.
306 std::vector<int64_t> possible_values;
// Returns false on infeasibility (as the callers below show). Keeps j in
// direct[i]'s domain only when inverse[j] can equal i.
309 const auto filter_inverse_domain =
310 [
context, n, &possible_values](
const auto& direct,
const auto& inverse) {
312 for (
int i = 0; i < n; ++i) {
313 possible_values.clear();
314 const Domain domain =
context->DomainOf(direct[i]);
315 bool removed_value =
false;
316 for (
const int64_t j : domain.Values()) {
317 if (
context->DomainOf(inverse[j]).Contains(i)) {
318 possible_values.push_back(j);
320 removed_value =
true;
324 if (!
context->IntersectDomainWith(
326 VLOG(1) <<
"Empty domain for a variable in ExpandInverse()";
336 if (!filter_inverse_domain(f_direct, f_inverse))
return;
337 if (!filter_inverse_domain(f_inverse, f_direct))
return;
// Share literals: (f_direct[i] == j) <=> (f_inverse[j] == i).
343 for (
int i = 0; i < n; ++i) {
344 const int f_i = f_direct[i];
345 for (
const int64_t j :
context->DomainOf(f_i).Values()) {
347 const int r_j = f_inverse[j];
349 if (
context->HasVarValueEncoding(r_j, i, &r_j_i)) {
350 context->InsertVarValueEncoding(r_j_i, f_i, j);
352 const int f_i_j =
context->GetOrCreateVarValueEncoding(f_i, j);
353 context->InsertVarValueEncoding(f_i_j, r_j, i);
359 context->UpdateRuleStats(
"inverse: expanded");
// Expands an element constraint whose index and target are the same variable
// (asserted by the DCHECK below).
//
// Keeps only the index values v for which vars(v) can take the value v, then
// for each remaining v ties the literal "index == v" to vars(v) being in
// {v}. The call that receives these arguments is not visible in this
// fragment; presumably it is an implied-domain helper. Confirm.
363void ExpandElementWithTargetEqualIndex(ConstraintProto*
ct,
365 const ElementConstraintProto& element =
ct->element();
366 DCHECK_EQ(element.index(), element.target());
368 const int index_ref = element.index();
// index == v requires vars(v) == v, so v is only valid if vars(v) may be v.
369 std::vector<int64_t> valid_indices;
370 for (
const int64_t v :
context->DomainOf(index_ref).Values()) {
371 if (!
context->DomainContains(element.vars(v), v))
continue;
372 valid_indices.push_back(v);
374 if (valid_indices.size() <
context->DomainOf(index_ref).Size()) {
375 if (!
context->IntersectDomainWith(index_ref,
377 VLOG(1) <<
"No compatible variable domains in "
378 "ExpandElementWithTargetEqualIndex()";
381 context->UpdateRuleStats(
"element: reduced index domain");
384 for (
const int64_t v :
context->DomainOf(index_ref).Values()) {
385 const int var = element.vars(v);
388 context->GetOrCreateVarValueEncoding(index_ref, v),
var, Domain(v));
391 "element: expanded with special case target = index");
// Expands an element constraint whose array entries are constants.
//
// `value` is presumably the constant held by vars(v). It is declared in
// lines not visible here; confirm.
//
// Visible steps:
//  - For every constant used by at least two index values, a bool_or
//    "target != value OR <one of the supporting index literals>" is created.
//  - An exactly_one over the index literals "index == v" is added.
//  - For a shared constant, each supporting index literal implies
//    "target == value" and is appended to that constant's support clause.
//    Otherwise the index literal is registered directly as the encoding of
//    "target == value".
396void ExpandConstantArrayElement(ConstraintProto*
ct, PresolveContext*
context) {
397 const ElementConstraintProto& element =
ct->element();
398 const int index_ref = element.index();
399 const int target_ref = element.target();
402 const Domain index_domain =
context->DomainOf(index_ref);
403 const Domain target_domain =
context->DomainOf(target_ref);
// Support clause per constant value that appears at two or more indices.
410 absl::flat_hash_map<int64_t, BoolArgumentProto*> supports;
412 absl::flat_hash_map<int64_t, int> constant_var_values_usage;
413 for (
const int64_t v : index_domain.Values()) {
// Create the clause once, the second time the value is seen.
416 if (++constant_var_values_usage[
value] == 2) {
418 BoolArgumentProto*
const support =
419 context->working_model->add_constraints()->mutable_bool_or();
420 const int target_literal =
421 context->GetOrCreateVarValueEncoding(target_ref,
value);
422 support->add_literals(
NegatedRef(target_literal));
423 supports[
value] = support;
432 context->working_model->add_constraints()->mutable_exactly_one();
433 for (
const int64_t v : index_domain.Values()) {
434 const int index_literal =
435 context->GetOrCreateVarValueEncoding(index_ref, v);
436 exactly_one->add_literals(index_literal);
439 const auto& it = supports.find(
value);
440 if (it != supports.end()) {
// Shared value: index literal => target literal, and it supports it.
443 const int target_literal =
444 context->GetOrCreateVarValueEncoding(target_ref,
value);
445 context->AddImplication(index_literal, target_literal);
446 it->second->add_literals(index_literal);
// Value used by this index only: both literals are the same.
449 context->InsertVarValueEncoding(index_literal, target_ref,
value);
454 context->UpdateRuleStats(
"element: expanded value element");
// Expands an element constraint (target = vars[index]) with variable
// entries.
//
// Adds a bool_or over all index literals "index == v", so at least one holds.
// For each v, "index == v" implies:
//  - target is in vars(v)'s single-value domain, when vars(v) is fixed;
//  - otherwise vars(v) - target == 0, via an enforced linear constraint.
459void ExpandVariableElement(ConstraintProto*
ct, PresolveContext*
context) {
460 const ElementConstraintProto& element =
ct->element();
461 const int index_ref = element.index();
462 const int target_ref = element.target();
463 const Domain index_domain =
context->DomainOf(index_ref);
465 BoolArgumentProto* bool_or =
466 context->working_model->add_constraints()->mutable_bool_or();
468 for (
const int64_t v : index_domain.Values()) {
469 const int var = element.vars(v);
470 const Domain var_domain =
context->DomainOf(
var);
471 const int index_lit =
context->GetOrCreateVarValueEncoding(index_ref, v);
472 bool_or->add_literals(index_lit);
474 if (var_domain.IsFixed()) {
475 context->AddImplyInDomain(index_lit, target_ref, var_domain);
// NOTE(review): this local `ct` shadows the function parameter `ct`. A
// distinct name would avoid confusion.
477 ConstraintProto*
const ct =
context->working_model->add_constraints();
478 ct->add_enforcement_literal(index_lit);
479 ct->mutable_linear()->add_vars(
var);
480 ct->mutable_linear()->add_coeffs(1);
481 ct->mutable_linear()->add_vars(target_ref);
482 ct->mutable_linear()->add_coeffs(-1);
483 ct->mutable_linear()->add_domain(0);
484 ct->mutable_linear()->add_domain(0);
488 context->UpdateRuleStats(
"element: expanded");
// Entry point for expanding an element constraint (target = vars[index]).
//
// Visible steps:
//  - Restricts the index to [0, size - 1].
//  - Delegates to ExpandElementWithTargetEqualIndex when index and target are
//    the same variable.
//  - Keeps only the index values v whose vars(v) domain intersects the
//    target domain, and tracks the union of those domains (`reached_domain`)
//    and whether they are all fixed (`all_constants`).
//  - Shrinks the index domain to the valid indices and the target domain to
//    `reached_domain`.
// The final dispatch is not visible. Presumably it goes to
// ExpandConstantArrayElement when `all_constants` and to
// ExpandVariableElement otherwise. Confirm.
492void ExpandElement(ConstraintProto*
ct, PresolveContext*
context) {
493 const ElementConstraintProto& element =
ct->element();
495 const int index_ref = element.index();
496 const int target_ref = element.target();
497 const int size = element.vars_size();
501 if (!
context->IntersectDomainWith(index_ref, Domain(0, size - 1))) {
502 VLOG(1) <<
"Empty domain for the index variable in ExpandElement()";
507 if (index_ref == target_ref) {
508 ExpandElementWithTargetEqualIndex(
ct,
context);
513 bool all_constants =
true;
514 std::vector<int64_t> valid_indices;
515 const Domain index_domain =
context->DomainOf(index_ref);
516 const Domain target_domain =
context->DomainOf(target_ref);
517 Domain reached_domain;
518 for (
const int64_t v : index_domain.Values()) {
519 const Domain var_domain =
context->DomainOf(element.vars(v));
// An index whose entry cannot equal the target can never be selected.
520 if (var_domain.IntersectionWith(target_domain).IsEmpty())
continue;
522 valid_indices.push_back(v);
523 reached_domain = reached_domain.UnionWith(var_domain);
524 if (var_domain.Min() != var_domain.Max()) {
525 all_constants =
false;
529 if (valid_indices.size() < index_domain.Size()) {
530 if (!
context->IntersectDomainWith(index_ref,
532 VLOG(1) <<
"No compatible variable domains in ExpandElement()";
536 context->UpdateRuleStats(
"element: reduced index domain");
// The target can only take values reachable through a valid index.
541 bool target_domain_changed =
false;
542 if (!
context->IntersectDomainWith(target_ref, reached_domain,
543 &target_domain_changed)) {
547 if (target_domain_changed) {
548 context->UpdateRuleStats(
"element: reduced target domain");
// Links tuple literals to value-encoding literals: literals[i] is associated
// with values[i], and `encoding` maps each value to its encoding literal.
// Every value must be a key of `encoding` (looked up with `.at()`).
//
// For each encoding literal e with supporting literals S:
//  - |S| == 1: e and the single supporter are stored as equal;
//  - otherwise: a bool_or (not e OR S...) is added, and each s in S implies
//    e. Together these make e equivalent to OR(S).
// The `context` parameter line is not visible in this fragment.
561void LinkLiteralsAndValues(
const std::vector<int>& literals,
562 const std::vector<int64_t>& values,
563 const absl::flat_hash_map<int64_t, int>& encoding,
565 CHECK_EQ(literals.size(), values.size());
// Ordered map, presumably so that constraints are created in a deterministic
// order. Confirm.
571 std::map<int, std::vector<int>> encoding_lit_to_support;
576 for (
int i = 0; i < values.size(); ++i) {
577 encoding_lit_to_support[encoding.at(values[i])].push_back(literals[i]);
582 for (
const auto& [encoding_lit, support] : encoding_lit_to_support) {
583 CHECK(!support.empty());
584 if (support.size() == 1) {
585 context->StoreBooleanEqualityRelation(encoding_lit, support[0]);
587 BoolArgumentProto* bool_or =
588 context->working_model->add_constraints()->mutable_bool_or();
589 bool_or->add_literals(
NegatedRef(encoding_lit));
590 for (
const int lit : support) {
591 bool_or->add_literals(lit);
592 context->AddImplication(lit, encoding_lit);
// Constrains the variable encoded by `encoding` to one of `reachable_values`
// whenever `literal` holds (the enforcement lines are not visible here;
// presumably `literal` is added as the enforcement literal. Confirm).
//
// Nothing is added when every encoded value is reachable. When at most half
// of the values are reachable, a bool_or over their literals is used.
// Otherwise a bool_and is built that presumably forbids each value not in
// `reachable_values`.
//
// NOTE(review): `encoding` is taken by value, so the whole map is copied on
// every call. A `const&` would avoid this and would not change callers.
// NOTE(review): the early size comparison presumably assumes
// `reachable_values` has no duplicates and only contains keys of `encoding`.
// Confirm (the non-const reference suggests it may be sorted/deduplicated
// in lines not visible here).
600void AddImplyInReachableValues(
int literal,
601 std::vector<int64_t>& reachable_values,
602 const absl::flat_hash_map<int64_t, int> encoding,
605 if (reachable_values.size() == encoding.size())
return;
// Few reachable values: list them positively.
606 if (reachable_values.size() <= encoding.size() / 2) {
608 ConstraintProto*
ct =
context->working_model->add_constraints();
610 BoolArgumentProto* bool_or =
ct->mutable_bool_or();
611 for (
const int64_t v : reachable_values) {
612 bool_or->add_literals(encoding.at(v));
// Many reachable values: exclude the unreachable ones instead.
616 absl::flat_hash_set<int64_t> set(reachable_values.begin(),
617 reachable_values.end());
618 ConstraintProto*
ct =
context->working_model->add_constraints();
620 BoolArgumentProto* bool_and =
ct->mutable_bool_and();
622 if (!set.contains(
value)) {
// Expands an automaton constraint over `vars` into Boolean constraints.
//
// Visible steps:
//  - Empty automaton (no vars): feasible iff the starting state is a final
//    state, otherwise the model is unsat. A non-empty automaton with no
//    transitions is unsat.
//  - Forward pass: reachable_states[time + 1] collects the heads of
//    transitions whose tail is reachable at `time` and whose label is in
//    vars[time]'s domain. At the last step, only final heads are kept.
//  - Backward pass: reachable_states[time] is reduced to the tails of
//    transitions leading to a state still reachable at time + 1.
//  - Per time step: collects the surviving (in_state, label, out_state)
//    tuples, restricts vars[time] to the labels, and builds literal encodings
//    of the label (`encoding`), the incoming state (`in_encoding`) and the
//    outgoing state (`out_encoding`), reusing literals where a state is
//    determined by a unique predecessor or label.
//  - Then either a "light" encoding (implications plus one clause per
//    tuple) or an exactly_one over tuple literals linked to all three
//    encodings via LinkLiteralsAndValues.
//
// NOTE(review): this fragment is incomplete. The loops over `time`, several
// declarations (`time`, `var`, `bool_or`, `tuple_literal`) and the closing
// braces are not visible. Confirm the step boundaries against the full
// source.
629void ExpandAutomaton(ConstraintProto*
ct, PresolveContext*
context) {
630 AutomatonConstraintProto&
proto = *
ct->mutable_automaton();
// Degenerate automata.
632 if (
proto.vars_size() == 0) {
633 const int64_t initial_state =
proto.starting_state();
634 for (
const int64_t final_state :
proto.final_states()) {
635 if (initial_state == final_state) {
636 context->UpdateRuleStats(
"automaton: empty and trivially feasible");
641 return (
void)
context->NotifyThatModelIsUnsat(
642 "automaton: empty with an initial state not in the final states.");
643 }
else if (
proto.transition_label_size() == 0) {
644 return (
void)
context->NotifyThatModelIsUnsat(
645 "automaton: non-empty with no transition.");
648 const int n =
proto.vars_size();
649 const std::vector<int> vars = {
proto.vars().begin(),
proto.vars().end()};
// reachable_states[t] holds the states the automaton can be in before
// reading vars[t]. Index n is the state after the last variable.
652 const absl::flat_hash_set<int64_t> final_states(
653 {
proto.final_states().begin(),
proto.final_states().end()});
654 std::vector<absl::flat_hash_set<int64_t>> reachable_states(n + 1);
655 reachable_states[0].insert(
proto.starting_state());
// Forward propagation of reachable states (loop over `time` not visible).
659 for (
int t = 0; t <
proto.transition_tail_size(); ++t) {
660 const int64_t
tail =
proto.transition_tail(t);
661 const int64_t label =
proto.transition_label(t);
662 const int64_t
head =
proto.transition_head(t);
663 if (!reachable_states[
time].contains(
tail))
continue;
664 if (!
context->DomainContains(vars[
time], label))
continue;
665 if (
time == n - 1 && !final_states.contains(
head))
continue;
666 reachable_states[
time + 1].insert(
head);
// Backward pruning: keep only states from which a final state is reachable.
672 absl::flat_hash_set<int64_t> new_set;
673 for (
int t = 0; t <
proto.transition_tail_size(); ++t) {
674 const int64_t
tail =
proto.transition_tail(t);
675 const int64_t label =
proto.transition_label(t);
676 const int64_t
head =
proto.transition_head(t);
678 if (!reachable_states[
time].contains(
tail))
continue;
679 if (!
context->DomainContains(vars[
time], label))
continue;
680 if (!reachable_states[
time + 1].contains(
head))
continue;
681 new_set.insert(
tail);
683 reachable_states[
time].
swap(new_set);
// Literal encodings of the label at the current step and of the
// incoming/outgoing states. in_encoding is carried over from the previous
// step via the swaps at the end of each branch.
691 absl::flat_hash_map<int64_t, int> encoding;
692 absl::flat_hash_map<int64_t, int> in_encoding;
693 absl::flat_hash_map<int64_t, int> out_encoding;
694 bool removed_values =
false;
// Surviving transitions at this step, as parallel vectors.
700 std::vector<int64_t> in_states;
701 std::vector<int64_t> labels;
702 std::vector<int64_t> out_states;
703 for (
int i = 0; i <
proto.transition_label_size(); ++i) {
704 const int64_t
tail =
proto.transition_tail(i);
705 const int64_t label =
proto.transition_label(i);
706 const int64_t
head =
proto.transition_head(i);
708 if (!reachable_states[
time].contains(
tail))
continue;
709 if (!reachable_states[
time + 1].contains(
head))
continue;
710 if (!
context->DomainContains(vars[
time], label))
continue;
715 in_states.push_back(
tail);
716 labels.push_back(label);
// At the last step, all final states are collapsed to 0, since the exact
// final state does not need to be distinguished.
720 out_states.push_back(
time + 1 == n ? 0 :
head);
// A single surviving transition fixes vars[time] to its label.
724 const int num_tuples = in_states.size();
725 if (num_tuples == 1) {
726 if (!
context->IntersectDomainWith(vars[
time], Domain(labels.front()))) {
727 VLOG(1) <<
"Infeasible automaton.";
// Restrict vars[time] to the surviving labels.
736 std::vector<int64_t> transitions = labels;
740 if (!
context->IntersectDomainWith(
742 VLOG(1) <<
"Infeasible automaton.";
749 for (
const int64_t v :
context->DomainOf(vars[
time]).Values()) {
750 encoding[v] =
context->GetOrCreateVarValueEncoding(vars[
time], v);
// Number of tuples per in-state, label and out-state. Used to detect
// literals that identify a single tuple.
757 absl::flat_hash_map<int64_t, int> in_count;
758 absl::flat_hash_map<int64_t, int> transition_count;
759 absl::flat_hash_map<int64_t, int> out_count;
760 for (
int i = 0; i < num_tuples; ++i) {
761 in_count[in_states[i]]++;
762 transition_count[labels[i]]++;
763 out_count[out_states[i]]++;
// Build out_encoding over the distinct out-states (the deduplication of
// `states` is not visible here).
770 std::vector<int64_t> states = out_states;
773 out_encoding.clear();
774 if (states.size() == 2) {
776 out_encoding[states[0]] =
var;
778 }
else if (states.size() > 2) {
// Tracks whether all values passed to Set() are equal.
779 struct UniqueDetector {
780 void Set(int64_t v) {
781 if (!is_unique)
return;
783 if (v !=
value) is_unique =
false;
790 bool is_unique =
true;
796 absl::flat_hash_map<int64_t, UniqueDetector> out_to_in;
797 absl::flat_hash_map<int64_t, UniqueDetector> out_to_transition;
798 for (
int i = 0; i < num_tuples; ++i) {
799 out_to_in[out_states[i]].Set(in_states[i]);
800 out_to_transition[out_states[i]].Set(labels[i]);
803 for (
const int64_t state : states) {
// If every tuple from the unique in-state leads to `state` (and only those
// tuples do), the out-state literal equals the in-state literal.
806 if (!in_encoding.empty() && out_to_in[state].is_unique) {
807 const int64_t unique_in = out_to_in[state].value;
808 if (in_count[unique_in] == out_count[state]) {
809 out_encoding[state] = in_encoding[unique_in];
// Same reasoning with a unique label.
816 if (!encoding.empty() && out_to_transition[state].is_unique) {
817 const int64_t unique_transition = out_to_transition[state].value;
818 if (transition_count[unique_transition] == out_count[state]) {
819 out_encoding[state] = encoding[unique_transition];
824 out_encoding[state] =
context->NewBoolVar();
// Choose the light encoding when there are more tuples than involved
// literals.
842 const int num_involved_variables =
843 in_encoding.size() + encoding.size() + out_encoding.size();
844 const bool use_light_encoding = (num_tuples > num_involved_variables);
845 if (use_light_encoding && !in_encoding.empty() && !encoding.empty() &&
846 !out_encoding.empty()) {
// Light encoding: each in-state literal restricts the possible labels and
// out-states, and each tuple yields in AND label => out.
850 absl::flat_hash_map<int64_t, std::vector<int64_t>> in_to_label;
851 absl::flat_hash_map<int64_t, std::vector<int64_t>> in_to_out;
852 for (
int i = 0; i < num_tuples; ++i) {
853 in_to_label[in_states[i]].push_back(labels[i]);
854 in_to_out[in_states[i]].push_back(out_states[i]);
856 for (
const auto [in_value, in_literal] : in_encoding) {
857 AddImplyInReachableValues(in_literal, in_to_label[in_value], encoding,
859 AddImplyInReachableValues(in_literal, in_to_out[in_value], out_encoding,
864 for (
int i = 0; i < num_tuples; ++i) {
866 context->working_model->add_constraints()->mutable_bool_or();
867 bool_or->add_literals(
NegatedRef(in_encoding.at(in_states[i])));
868 bool_or->add_literals(
NegatedRef(encoding.at(labels[i])));
869 bool_or->add_literals(out_encoding.at(out_states[i]));
872 in_encoding.swap(out_encoding);
873 out_encoding.clear();
// Tuple encoding: one literal per tuple, exactly one of which is true.
881 std::vector<int> tuple_literals;
882 if (num_tuples == 2) {
// Two tuples: a single Boolean and its negation suffice.
883 const int bool_var =
context->NewBoolVar();
884 tuple_literals.push_back(bool_var);
885 tuple_literals.push_back(
NegatedRef(bool_var));
890 BoolArgumentProto* exactly_one =
891 context->working_model->add_constraints()->mutable_exactly_one();
892 for (
int i = 0; i < num_tuples; ++i) {
// Reuse an existing literal when it identifies this tuple alone.
894 if (in_count[in_states[i]] == 1 && !in_encoding.empty()) {
895 tuple_literal = in_encoding[in_states[i]];
896 }
else if (transition_count[labels[i]] == 1 && !encoding.empty()) {
897 tuple_literal = encoding[labels[i]];
898 }
else if (out_count[out_states[i]] == 1 && !out_encoding.empty()) {
899 tuple_literal = out_encoding[out_states[i]];
901 tuple_literal =
context->NewBoolVar();
904 tuple_literals.push_back(tuple_literal);
905 exactly_one->add_literals(tuple_literal);
909 if (!in_encoding.empty()) {
910 LinkLiteralsAndValues(tuple_literals, in_states, in_encoding,
context);
912 if (!encoding.empty()) {
913 LinkLiteralsAndValues(tuple_literals, labels, encoding,
context);
915 if (!out_encoding.empty()) {
916 LinkLiteralsAndValues(tuple_literals, out_states, out_encoding,
context);
919 in_encoding.swap(out_encoding);
920 out_encoding.clear();
923 if (removed_values) {
924 context->UpdateRuleStats(
"automaton: reduced variable domains");
926 context->UpdateRuleStats(
"automaton: expanded");
// Expands a negated (forbidden-tuples) table constraint into clauses.
//
// The flat `values` array is split into tuples of `num_vars` entries. An
// empty tuple list only records a rule stat here (the rest of that branch is
// not visible). Otherwise, one bool_or per forbidden tuple is added, built
// from the value-encoding literals of its non-wildcard entries (entries
// equal to `any_value` are skipped). Presumably the clause holds the negated
// literals, i.e. "some variable differs from the tuple". Confirm, since the
// push into `clause` is not visible.
//
// NOTE(review): `num_vars` is used as a divisor with no visible guard
// against zero. Confirm the caller never passes a table without variables.
930void ExpandNegativeTable(ConstraintProto*
ct, PresolveContext*
context) {
931 TableConstraintProto& table = *
ct->mutable_table();
932 const int num_vars = table.vars_size();
933 const int num_original_tuples = table.values_size() / num_vars;
934 std::vector<std::vector<int64_t>> tuples(num_original_tuples);
936 for (
int i = 0; i < num_original_tuples; ++i) {
937 for (
int j = 0; j < num_vars; ++j) {
938 tuples[i].push_back(table.values(count++));
942 if (tuples.empty()) {
943 context->UpdateRuleStats(
"table: empty negated constraint");
950 std::vector<int64_t> domain_sizes;
951 for (
int i = 0; i < num_vars; ++i) {
952 domain_sizes.push_back(
context->DomainOf(table.vars(i)).Size());
// One clause per forbidden tuple.
957 std::vector<int> clause;
958 for (
const std::vector<int64_t>& tuple : tuples) {
960 for (
int i = 0; i < num_vars; ++i) {
961 const int64_t
value = tuple[i];
// Wildcard entries match every value and contribute no literal.
962 if (
value == any_value)
continue;
965 context->GetOrCreateVarValueEncoding(table.vars(i),
value);
970 BoolArgumentProto* bool_or =
971 context->working_model->add_constraints()->mutable_bool_or();
972 for (
const int lit : clause) {
973 bool_or->add_literals(lit);
976 context->UpdateRuleStats(
"table: expanded negated constraint");
// Links the tuple literals of a positive table to the value encoding of one
// column (`variable`). tuple_literals[i] selects a tuple whose entry for
// this column is values[i]. An entry equal to `any_value` matches every
// value.
//
// For each distinct non-wildcard value:
//  - exactly one supporting tuple and no wildcard tuple: the tuple literal
//    is registered as the encoding literal of variable == value;
//  - otherwise: each supporting tuple literal implies the value literal, and
//    a bool_or (not value_literal OR supporters OR wildcard tuples) is added.
//    Together these make value_literal imply that some compatible tuple is
//    selected.
//
// NOTE(review): the creation of `value_literal` and any reset of `selected`
// between groups are not visible in this fragment. Confirm.
986void ProcessOneVariable(
const std::vector<int>& tuple_literals,
987 const std::vector<int64_t>& values,
int variable,
988 int64_t any_value, PresolveContext*
context) {
989 VLOG(2) <<
"Process var(" << variable <<
") with domain "
990 <<
context->DomainOf(variable) <<
" and " << values.size()
992 CHECK_EQ(tuple_literals.size(), values.size());
// Split tuples into wildcard ones and (value, tuple literal) pairs.
995 std::vector<int> tuples_with_any;
996 std::vector<std::pair<int64_t, int>> pairs;
997 for (
int i = 0; i < values.size(); ++i) {
998 const int64_t
value = values[i];
999 if (
value == any_value) {
1000 tuples_with_any.push_back(tuple_literals[i]);
1004 pairs.emplace_back(
value, tuple_literals[i]);
// Sorting groups equal values together. Each inner loop consumes one group.
1009 std::vector<int> selected;
1010 std::sort(pairs.begin(), pairs.end());
1011 for (
int i = 0; i < pairs.size();) {
1013 const int64_t
value = pairs[i].first;
1014 for (; i < pairs.size() && pairs[i].first ==
value; ++i) {
1015 selected.push_back(pairs[i].second);
1018 CHECK(!selected.empty() || !tuples_with_any.empty());
// A single exclusive supporter is the value literal itself.
1019 if (selected.size() == 1 && tuples_with_any.empty()) {
1020 context->InsertVarValueEncoding(selected.front(), variable,
value);
1022 const int value_literal =
1024 BoolArgumentProto* no_support =
1025 context->working_model->add_constraints()->mutable_bool_or();
1026 for (
const int lit : selected) {
1027 no_support->add_literals(lit);
1028 context->AddImplication(lit, value_literal);
1030 for (
const int lit : tuples_with_any) {
1031 no_support->add_literals(lit);
1035 no_support->add_literals(
NegatedRef(value_literal));
// Encodes a positive table over exactly two variables with support
// constraints, without creating one literal per tuple.
//
// For each tuple (l, r), the literals "left == l" and "right == r" are
// created, and each one is recorded as a support of the other. Then, for
// every value literal:
//  - if its support list covers all values of the other variable, nothing
//    is added;
//  - if it has a single support, an implication is added;
//  - otherwise a bool_or over the supports is added (its enforcement by the
//    literal is presumably set in a line not visible here. Confirm).
// The handling of a fixed left or right variable is not visible.
//
// NOTE(review): the "covers all values" shortcut compares
// support_literals.size() to the other variable's number of values. It
// assumes `tuples` holds no duplicate pairs. Duplicates would inflate the
// count and could skip a needed constraint. Confirm that callers
// deduplicate.
1041void AddSizeTwoTable(
1042 const std::vector<int>& vars,
1043 const std::vector<std::vector<int64_t>>& tuples,
1044 const std::vector<absl::flat_hash_set<int64_t>>& values_per_var,
1047 const int left_var = vars[0];
1048 const int right_var = vars[1];
1049 if (
context->DomainOf(left_var).IsFixed() ||
1050 context->DomainOf(right_var).IsFixed()) {
// Value literal -> supporting value literals of the other variable.
1056 std::map<int, std::vector<int>> left_to_right;
1057 std::map<int, std::vector<int>> right_to_left;
1059 for (
const auto& tuple : tuples) {
1060 const int64_t left_value(tuple[0]);
1061 const int64_t right_value(tuple[1]);
1063 CHECK(
context->DomainContains(right_var, right_value));
1065 const int left_literal =
1066 context->GetOrCreateVarValueEncoding(left_var, left_value);
1067 const int right_literal =
1068 context->GetOrCreateVarValueEncoding(right_var, right_value);
1069 left_to_right[left_literal].push_back(right_literal);
1070 right_to_left[right_literal].push_back(left_literal);
// Counters used only for the VLOG summary below.
1073 int num_implications = 0;
1074 int num_clause_added = 0;
1075 int num_large_clause_added = 0;
1076 auto add_support_constraint =
1077 [
context, &num_clause_added, &num_large_clause_added, &num_implications](
1078 int lit,
const std::vector<int>& support_literals,
1079 int max_support_size) {
// Full support: `lit` is compatible with every value, so nothing is needed.
1080 if (support_literals.size() == max_support_size)
return;
1081 if (support_literals.size() == 1) {
1082 context->AddImplication(lit, support_literals.front());
1085 BoolArgumentProto* bool_or =
1086 context->working_model->add_constraints()->mutable_bool_or();
1087 for (
const int support_literal : support_literals) {
1088 bool_or->add_literals(support_literal);
1092 if (support_literals.size() > max_support_size / 2) {
1093 num_large_clause_added++;
1098 for (
const auto& it : left_to_right) {
1099 add_support_constraint(it.first, it.second, values_per_var[1].size());
1101 for (
const auto& it : right_to_left) {
1102 add_support_constraint(it.first, it.second, values_per_var[0].size());
1104 VLOG(2) <<
"Table: 2 variables, " << tuples.size() <<
" tuples encoded using "
1105 << num_clause_added <<
" clauses, including "
1106 << num_large_clause_added <<
" large clauses, " << num_implications
// Expands a positive (allowed-tuples) table constraint.
//
// Visible steps:
//  - Splits the flat values into tuples, drops tuples with a value outside
//    the current variable domains, and records the values seen per column.
//    If no tuple survives, the model is unsat.
//  - Records stats when all variables, or all but one, are fixed.
//  - Two variables: delegates to AddSizeTwoTable.
//  - Otherwise counts distinct prefixes (all columns but the last) and
//    compresses tuples (the compression call is not visible). It then
//    creates one literal per tuple under an exactly_one and links the
//    literals to each non-fixed column via ProcessOneVariable.
//
// NOTE(review): `num_vars` is used as a divisor with no visible guard
// against zero.
// NOTE(review): `prefixes` stores absl::Span views into the inner vectors of
// `tuples`. They are only valid until `tuples` is modified. As visible, they
// are read (prefixes.size()) before the compression step, which is fine.
// Confirm no later use.
1110void ExpandPositiveTable(ConstraintProto*
ct, PresolveContext*
context) {
1111 const TableConstraintProto& table =
ct->table();
1112 const int num_vars = table.vars_size();
1113 const int num_original_tuples = table.values_size() / num_vars;
1116 const std::vector<int> vars(table.vars().begin(), table.vars().end());
1117 std::vector<std::vector<int64_t>> tuples(num_original_tuples);
1119 for (
int tuple_index = 0; tuple_index < num_original_tuples; ++tuple_index) {
1120 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1121 tuples[tuple_index].push_back(table.values(count++));
// Keep only tuples compatible with the current domains, compacting them in
// place to the front (swap + resize), and collect the values per column.
1127 std::vector<absl::flat_hash_set<int64_t>> values_per_var(num_vars);
1129 for (
int tuple_index = 0; tuple_index < num_original_tuples; ++tuple_index) {
1131 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1132 const int64_t
value = tuples[tuple_index][var_index];
1133 if (!
context->DomainContains(vars[var_index],
value)) {
1139 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1140 values_per_var[var_index].insert(tuples[tuple_index][var_index]);
1142 std::swap(tuples[tuple_index], tuples[new_size]);
1146 tuples.resize(new_size);
1147 const int num_valid_tuples = tuples.size();
1149 if (tuples.empty()) {
1150 context->UpdateRuleStats(
"table: empty");
1151 return (
void)
context->NotifyThatModelIsUnsat();
// Presumably restricts each variable to its column's values (the start of
// the call is not visible) and counts the fixed variables.
1157 int num_fixed_variables = 0;
1158 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1162 values_per_var[var_index].end()})));
1163 if (
context->DomainOf(vars[var_index]).IsFixed()) {
1164 num_fixed_variables++;
1168 if (num_fixed_variables == num_vars - 1) {
1169 context->UpdateRuleStats(
"table: one variable not fixed");
1172 }
else if (num_fixed_variables == num_vars) {
1173 context->UpdateRuleStats(
"table: all variables fixed");
1179 if (num_vars == 2) {
1180 AddSizeTwoTable(vars, tuples, values_per_var,
context);
1182 "table: expanded positive constraint with two variables");
// Count distinct prefixes (every column except the last).
1189 int num_prefix_tuples = 0;
1191 absl::flat_hash_set<absl::Span<const int64_t>> prefixes;
1192 for (
const std::vector<int64_t>& tuple : tuples) {
1193 prefixes.insert(absl::MakeSpan(tuple.data(), num_vars - 1));
1195 num_prefix_tuples = prefixes.size();
// Tuple compression (call not visible) may merge tuples using wildcard
// entries.
1202 std::vector<int64_t> domain_sizes;
1203 for (
int i = 0; i < num_vars; ++i) {
1204 domain_sizes.push_back(values_per_var[i].size());
1206 const int num_tuples_before_compression = tuples.size();
1208 const int num_compressed_tuples = tuples.size();
1209 if (num_compressed_tuples < num_tuples_before_compression) {
1210 context->UpdateRuleStats(
"table: compress tuples");
1213 if (num_compressed_tuples == 1) {
1215 context->UpdateRuleStats(
"table: one tuple");
1221 const bool prefixes_are_all_different = num_prefix_tuples == num_valid_tuples;
1222 if (prefixes_are_all_different) {
1224 "TODO table: last value implied by previous values");
// Upper bound on the number of distinct prefixes (saturating product), used
// for the log message only.
1234 int64_t max_num_prefix_tuples = 1;
1235 for (
int var_index = 0; var_index + 1 < num_vars; ++var_index) {
1236 max_num_prefix_tuples =
1237 CapProd(max_num_prefix_tuples, values_per_var[var_index].size());
1241 absl::StrCat(
"Table: ", num_vars,
1242 " variables, original tuples = ", num_original_tuples);
1243 if (num_valid_tuples != num_original_tuples) {
1244 absl::StrAppend(&
message,
", valid tuples = ", num_valid_tuples);
1246 if (prefixes_are_all_different) {
1247 if (num_prefix_tuples < max_num_prefix_tuples) {
1248 absl::StrAppend(&
message,
", partial prefix = ", num_prefix_tuples,
"/",
1249 max_num_prefix_tuples);
1251 absl::StrAppend(&
message,
", full prefix = true");
1254 absl::StrAppend(&
message,
", num prefix tuples = ", num_prefix_tuples);
1256 if (num_compressed_tuples != num_valid_tuples) {
1258 ", compressed tuples = ", num_compressed_tuples);
1264 if (num_compressed_tuples == 2) {
1265 context->UpdateRuleStats(
"TODO table: two tuples");
// One literal per compressed tuple. Exactly one tuple is selected.
1270 std::vector<int> tuple_literals(num_compressed_tuples);
1271 BoolArgumentProto* exactly_one =
1272 context->working_model->add_constraints()->mutable_exactly_one();
1273 for (
int i = 0; i < num_compressed_tuples; ++i) {
1274 tuple_literals[i] =
context->NewBoolVar();
1275 exactly_one->add_literals(tuple_literals[i]);
// Link the tuple literals to each column that has more than one value.
1278 std::vector<int64_t> values(num_compressed_tuples);
1279 for (
int var_index = 0; var_index < num_vars; ++var_index) {
1280 if (values_per_var[var_index].size() == 1)
continue;
1281 for (
int i = 0; i < num_compressed_tuples; ++i) {
1282 values[i] = tuples[i][var_index];
1284 ProcessOneVariable(tuple_literals, values, vars[var_index], any_value,
1288 context->UpdateRuleStats(
"table: expanded positive constraint");
// Decides whether an all_diff constraint should be expanded into Boolean
// constraints, given the union of its expressions' domains.
//
// Visible thresholds, whose return values are not visible and presumably
// are `true`:
//  - union size <= 2 * number of expressions, or union size <= 32;
//  - every expression fully encoded (the counting condition is not visible)
//    and union size < 256.
//
// NOTE(review): `proto` is bound to a const reference but obtained via
// `ct->mutable_all_diff()`. Nothing visible here mutates it, so the
// read-only accessor would express the intent better. Confirm.
1292bool AllDiffShouldBeExpanded(
const Domain& union_of_domains,
1293 ConstraintProto*
ct, PresolveContext*
context) {
1294 const AllDifferentConstraintProto&
proto = *
ct->mutable_all_diff();
1295 const int num_exprs =
proto.exprs_size();
1296 int num_fully_encoded = 0;
1297 for (
int i = 0; i < num_exprs; ++i) {
1299 num_fully_encoded++;
// Small value ranges relative to the number of expressions.
1303 if ((union_of_domains.Size() <= 2 *
proto.exprs_size()) ||
1304 (union_of_domains.Size() <= 32)) {
// Everything is already encoded, so expansion adds few new literals.
1309 if (num_fully_encoded == num_exprs && union_of_domains.Size() < 256) {
1316void ExpandAllDiff(
bool force_alldiff_expansion, ConstraintProto*
ct,
1318 AllDifferentConstraintProto&
proto = *
ct->mutable_all_diff();
1319 if (
proto.exprs_size() <= 1)
return;
1321 const int num_exprs =
proto.exprs_size();
1322 Domain union_of_domains =
context->DomainSuperSetOf(
proto.exprs(0));
1323 for (
int i = 1; i < num_exprs; ++i) {
1325 union_of_domains.UnionWith(
context->DomainSuperSetOf(
proto.exprs(i)));
1328 if (!AllDiffShouldBeExpanded(union_of_domains,
ct,
context) &&
1329 !force_alldiff_expansion) {
1333 const bool is_a_permutation = num_exprs == union_of_domains.Size();
1338 for (
const int64_t v : union_of_domains.Values()) {
1340 std::vector<LinearExpressionProto> possible_exprs;
1341 int fixed_expression_count = 0;
1342 for (
const LinearExpressionProto& expr :
proto.exprs()) {
1343 if (!
context->DomainContains(expr, v))
continue;
1344 possible_exprs.push_back(expr);
1346 fixed_expression_count++;
1350 if (fixed_expression_count > 1) {
1352 return (
void)
context->NotifyThatModelIsUnsat();
1353 }
else if (fixed_expression_count == 1) {
1355 for (
const LinearExpressionProto& expr : possible_exprs) {
1356 if (
context->IsFixed(expr))
continue;
1357 if (!
context->IntersectDomainWith(expr, Domain(v).Complement())) {
1358 VLOG(1) <<
"Empty domain for a variable in ExpandAllDiff()";
1364 BoolArgumentProto* at_most_or_equal_one =
1366 ?
context->working_model->add_constraints()->mutable_exactly_one()
1367 :
context->working_model->add_constraints()->mutable_at_most_one();
1368 for (
const LinearExpressionProto& expr : possible_exprs) {
1371 const int encoding =
context->GetOrCreateAffineValueEncoding(expr, v);
1372 at_most_or_equal_one->add_literals(encoding);
1375 if (is_a_permutation) {
1376 context->UpdateRuleStats(
"all_diff: permutation expanded");
1378 context->UpdateRuleStats(
"all_diff: expanded");
1389void ExpandSomeLinearOfSizeTwo(ConstraintProto*
ct, PresolveContext*
context) {
1390 const LinearConstraintProto& arg =
ct->linear();
1391 if (arg.vars_size() != 2)
return;
1393 const int var1 = arg.vars(0);
1394 const int var2 = arg.vars(1);
1397 const int64_t coeff1 = arg.coeffs(0);
1398 const int64_t coeff2 = arg.coeffs(1);
1400 const Domain reachable_rhs_superset =
1401 context->DomainOf(var1).MultiplicationBy(coeff1).AdditionWith(
1402 context->DomainOf(var2).MultiplicationBy(coeff2));
1404 const Domain infeasible_reachable_values =
1405 reachable_rhs_superset.IntersectionWith(
1409 if (infeasible_reachable_values.Size() != 1)
return;
1414 int64_t cte = infeasible_reachable_values.FixedValue();
1419 context->UpdateRuleStats(
"linear: expand always feasible ax + by != cte");
1423 const Domain reduced_domain =
1425 .AdditionWith(Domain(-x0))
1426 .InverseMultiplicationBy(
b)
1427 .IntersectionWith(
context->DomainOf(var2)
1428 .AdditionWith(Domain(-y0))
1429 .InverseMultiplicationBy(-
a));
1431 if (reduced_domain.Size() > 16)
return;
1436 const int64_t size1 =
context->DomainOf(var1).Size();
1437 const int64_t size2 =
context->DomainOf(var2).Size();
1438 for (
const int64_t z : reduced_domain.Values()) {
1439 const int64_t value1 = x0 +
b * z;
1440 const int64_t value2 = y0 -
a * z;
1441 DCHECK(
context->DomainContains(var1, value1)) <<
"value1 = " << value1;
1442 DCHECK(
context->DomainContains(var2, value2)) <<
"value2 = " << value2;
1443 DCHECK_EQ(coeff1 * value1 + coeff2 * value2,
1444 infeasible_reachable_values.FixedValue());
1446 if (!
context->HasVarValueEncoding(var1, value1,
nullptr) || size1 == 2) {
1449 if (!
context->HasVarValueEncoding(var2, value2,
nullptr) || size2 == 2) {
1456 for (
const int64_t z : reduced_domain.Values()) {
1457 const int64_t value1 = x0 +
b * z;
1458 const int64_t value2 = y0 -
a * z;
1460 const int lit1 =
context->GetOrCreateVarValueEncoding(var1, value1);
1461 const int lit2 =
context->GetOrCreateVarValueEncoding(var2, value2);
1463 context->working_model->add_constraints()->mutable_bool_or();
1466 for (
const int lit :
ct->enforcement_literal()) {
1471 context->UpdateRuleStats(
"linear: expand small ax + by != cte");
1478 if (
context->params().disable_constraint_expansion())
return;
1479 if (
context->ModelIsUnsat())
return;
1483 if (
context->ModelIsExpanded())
return;
1486 context->InitializeNewDomains();
1489 context->ClearPrecedenceCache();
1492 for (
int i = 0; i <
context->working_model->constraints_size(); ++i) {
1495 switch (
ct->constraint_case()) {
1496 case ConstraintProto::ConstraintCase::kReservoir:
1499 case ConstraintProto::ConstraintCase::kIntMod:
1502 case ConstraintProto::ConstraintCase::kIntProd:
1505 case ConstraintProto::ConstraintCase::kElement:
1508 case ConstraintProto::ConstraintCase::kInverse:
1511 case ConstraintProto::ConstraintCase::kAutomaton:
1514 case ConstraintProto::ConstraintCase::kTable:
1515 if (
ct->table().negated()) {
1528 context->UpdateNewConstraintsVariableUsage();
1530 context->UpdateConstraintVariableUsage(i);
1534 if (
context->ModelIsUnsat()) {
1543 for (
int i = 0; i <
context->working_model->constraints_size(); ++i) {
1546 switch (
ct->constraint_case()) {
1547 case ConstraintProto::ConstraintCase::kAllDiff:
1548 ExpandAllDiff(
context->params().expand_alldiff_constraints(),
ct,
1551 case ConstraintProto::ConstraintCase::kLinear:
1562 context->UpdateNewConstraintsVariableUsage();
1564 context->UpdateConstraintVariableUsage(i);
1568 if (
context->ModelIsUnsat()) {
1578 context->ClearPrecedenceCache();
1581 context->InitializeNewDomains();
1584 for (
int i = 0; i <
context->working_model->variables_size(); ++i) {
1586 context->working_model->mutable_variables(i));
1589 context->NotifyThatModelIsExpanded();
#define CHECK_EQ(val1, val2)
#define DCHECK(condition)
#define DCHECK_EQ(val1, val2)
#define VLOG(verboselevel)
Domain Complement() const
Returns the set Int64 ∖ D.
static Domain FromValues(std::vector< int64_t > values)
Creates a domain from the union of an unsorted list of integer values.
friend void swap(CpModelProto &a, CpModelProto &b)
GurobiMPCallbackContext * context
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
void swap(IdMap< K, V > &a, IdMap< K, V > &b)
bool RefIsPositive(int ref)
void CompressTuples(absl::Span< const int64_t > domain_sizes, int64_t any_value, std::vector< std::vector< int64_t > > *tuples)
void ExpandCpModel(PresolveContext *context)
bool SolveDiophantineEquationOfSizeTwo(int64_t &a, int64_t &b, int64_t &cte, int64_t &x0, int64_t &y0)
void FillDomainInProto(const Domain &domain, ProtoWithDomain *proto)
Domain ReadDomainFromProto(const ProtoWithDomain &proto)
void AddLinearExpressionToLinearConstraint(const LinearExpressionProto &expr, int64_t coefficient, LinearConstraintProto *linear)
Collection of objects used to extend the Constraint Solver library.
int64_t CapAdd(int64_t x, int64_t y)
int64_t CapSub(int64_t x, int64_t y)
std::string ProtobufShortDebugString(const P &message)
int64_t CapProd(int64_t x, int64_t y)
#define SOLVER_LOG(logger,...)
#define VLOG_IS_ON(verboselevel)