diff --git a/src/flatzinc/model.cc b/src/flatzinc/model.cc index b2004f0c8a..3e5e554ebb 100644 --- a/src/flatzinc/model.cc +++ b/src/flatzinc/model.cc @@ -216,6 +216,67 @@ bool Domain::Contains(int64 value) const { } } +namespace { +bool IntervalOverlapValues(int64 lb, int64 ub, const std::vector& values) { + for (int64 value : values) { + if (lb <= value && value <= ub) { + return true; + } + } + return false; +} +} // namespace + +bool Domain::OverlapsIntList(const std::vector& vec) const { + if (IsAllInt64()) { + return true; + } + if (is_interval) { + CHECK(!values.empty()); + return IntervalOverlapValues(values[0], values[1], vec); + } else { + // TODO(user): Better algorithm, sort and compare increasingly. + const std::vector& to_scan = + values.size() <= vec.size() ? values : vec; + const std::unordered_set container = + values.size() <= vec.size() + ? std::unordered_set(vec.begin(), vec.end()) + : std::unordered_set(values.begin(), values.end()); + for (int64 value : to_scan) { + if (ContainsKey(container, value)) { + return true; + } + } + return false; + } +} + +bool Domain::OverlapsIntInterval(int64 lb, int64 ub) const { + if (IsAllInt64()) { + return true; + } + if (is_interval) { + CHECK(!values.empty()); + const int64 dlb = values[0]; + const int64 dub = values[1]; + return !(dub < lb || dlb > ub); + } else { + return IntervalOverlapValues(lb, ub, values); + } +} + +bool Domain::OverlapsDomain(const Domain& other) const { + if (other.is_interval) { + if (other.values.empty()) { + return true; + } else { + return OverlapsIntInterval(other.values[0], other.values[1]); + } + } else { + return OverlapsIntList(other.values); + } +} + bool Domain::RemoveValue(int64 value) { if (is_interval) { if (values.empty()) { diff --git a/src/flatzinc/model.h b/src/flatzinc/model.h index 0efb2f6251..e735584ca3 100644 --- a/src/flatzinc/model.h +++ b/src/flatzinc/model.h @@ -67,7 +67,11 @@ struct Domain { // Returns true if the domain is [kint64min..kint64max] bool IsAllInt64() const; + // Various inclusion tests on a domain. bool Contains(int64 value) const; + bool OverlapsIntList(const std::vector& values) const; + bool OverlapsIntInterval(int64 lb, int64 ub) const; + bool OverlapsDomain(const Domain& other) const; // All the following modifiers change the internal representation // list to interval or interval to list. diff --git a/src/flatzinc/presolve.cc b/src/flatzinc/presolve.cc index 42cc447049..bca0e3277f 100644 --- a/src/flatzinc/presolve.cc +++ b/src/flatzinc/presolve.cc @@ -81,6 +81,63 @@ std::unordered_set GetValueSet(const Argument& arg) { return result; } +void SetConstraintAsIntEq(Constraint* ct, IntegerVariable* var, int64 value) { + ct->type = "int_eq"; + ct->arguments.clear(); + ct->arguments.push_back(Argument::IntVarRef(var)); + ct->arguments.push_back(Argument::IntegerValue(value)); +} + +bool OverlapsAt(const Argument& array, int pos, const Argument& other) { + if (array.type == Argument::INT_VAR_REF_ARRAY) { + const Domain& domain = array.variables[pos]->domain; + if (domain.IsAllInt64()) { + return true; + } + switch (other.type) { + case Argument::INT_VALUE: { + return domain.Contains(other.Value()); + } + case Argument::INT_INTERVAL: { + return domain.OverlapsIntInterval(other.values[0], other.values[1]); + } + case Argument::INT_LIST: { + return domain.OverlapsIntList(other.values); + } + case Argument::INT_VAR_REF: { + return domain.OverlapsDomain(other.variables[0]->domain); + } + default: { + LOG(FATAL) << "Case not supported in OverlapsAt"; + return false; + } + } + } else if (array.type == Argument::INT_LIST) { + const int64 value = array.values[pos]; + switch (other.type) { + case Argument::INT_VALUE: { + return value == other.values[0]; + } + case Argument::INT_INTERVAL: { + return other.values[0] <= value && value <= other.values[1]; + } + case Argument::INT_LIST: { + return std::find(other.values.begin(), other.values.end(), value) != + other.values.end(); + } + case Argument::INT_VAR_REF: { + return other.variables[0]->domain.Contains(value); + } + default: { + LOG(FATAL) << "Case not supported in OverlapsAt"; + return false; + } + } + } else { + LOG(FATAL) << "First argument not supported in OverlapsAt"; + return false; + } +} } // namespace // For the author's reference, here is an indicative list of presolve rules @@ -580,6 +637,17 @@ bool Presolver::PresolveIntTimes(Constraint* ct, std::string* log) { // TODO(user): Treat overflow correctly. } } + + // Special case: multiplication by zero. + if ((ct->arguments[0].HasOneValue() && ct->arguments[0].Value() == 0) || + (ct->arguments[1].HasOneValue() && ct->arguments[1].Value() == 0)) { + ct->type = "int_eq"; + ct->arguments[0] = ct->arguments[2]; + ct->arguments.resize(1); + ct->arguments.push_back(Argument::IntegerValue(0)); + return true; + } + return false; } @@ -865,25 +933,60 @@ bool Presolver::PresolveIntLinLt(Constraint* ct, std::string* log) { // with (c1 == 1 or c2 % c1 == 0) and xx = eq, ne, lt, le, gt, ge // Output: int_xx(x, c2 / c1) and int_xx_reif(x, c2 / c1, b) bool Presolver::SimplifyUnaryLinear(Constraint* ct, std::string* log) { - if (!ct->arguments[0].HasOneValue()) { + if (!ct->arguments[0].HasOneValue() || + ct->arguments[1].variables.size() != 1) { return false; } - const int64 coefficient = ct->arguments[0].Value(); + const int64 coeff = ct->arguments[0].values.front(); const int64 rhs = ct->arguments[2].Value(); - if (coefficient == 1 || (coefficient > 0 && rhs % coefficient == 0)) { - // TODO(user): Support coefficient = 0. + IntegerVariable* const var = ct->arguments[1].variables.front(); + const std::string op = ct->type.substr(8, 2); + bool changed = false; + int64 new_rhs = 0; + + if (coeff == 0) { + ct->arguments[0].values.clear(); + ct->arguments[1].variables.clear(); + // Will be process by PresolveLinear. + return true; + } + + if (op == "eq") { + if (rhs % coeff == 0) { + changed = true; + new_rhs = rhs / coeff; + } else { // Equality always false. + if (ct->arguments.size() == 4) { // reified version. + SetConstraintAsIntEq(ct, ct->arguments[3].Var(), 0); + } else { + ct->SetAsFalse(); + } + return true; + } + } else if (op == "ne") { + if (rhs % coeff == 0) { + changed = true; + new_rhs = rhs / coeff; + } else { // Equality always true. + if (ct->arguments.size() == 4) { // reified version. + SetConstraintAsIntEq(ct, ct->arguments[3].Var(), 1); + } else { + ct->MarkAsInactive(); + } + return true; + } + } else if (coeff >= 0 && rhs % coeff == 0) { // TODO(user): Support coefficient < 0 (and reverse the inequalities). - // TODO(user): Support rhs % coefficient != 0, and no the correct - // rounding in the case of inequalities, of false model in the case of - // equalities. + // TODO(user): Support rhs % coefficient != 0, and do the correct + // rounding in the case of inequalities. + changed = true; + new_rhs = rhs / coeff; + } + if (changed) { log->append("remove linear part"); // transform arguments. - ct->arguments[0].type = Argument::INT_VAR_REF; - ct->arguments[0].values.clear(); - ct->arguments[0].variables.push_back(ct->arguments[1].variables[0]); - ct->arguments[1].type = Argument::INT_VALUE; - ct->arguments[1].variables.clear(); - ct->arguments[1].values.push_back(rhs / coefficient); + ct->arguments[0] = Argument::IntVarRef(var); + ct->arguments[1] = Argument::IntegerValue(new_rhs); ct->RemoveArg(2); // Change type (remove "_lin" part). DCHECK(ct->type.size() >= 8 && ct->type.substr(3, 4) == "_lin"); @@ -1132,10 +1235,14 @@ bool Presolver::PresolveArrayIntElement(Constraint* ct, std::string* log) { // with yyy is the opposite of xxx (eq -> eq, ne -> ne, le -> ge, // lt -> gt, ge -> le, gt -> lt) // -// Rule 2: +// Rule 2a: // Input: int_lin_xxx[[c1, .., cn], [c'1, .., c'n], c0] (no variables) // Output: inactive or false constraint. // +// Rule 2b: +// Input: int_lin_xxx[[], [], c0] or int_lin_xxx_reif([], [], c0) +// Output: inactive or false constraint. +// // Rule 3: // Input: int_lin_xxx_reif[[c1, .., cn], [c'1, .., c'n], c0] (no variables) // Output: bool_eq(c0, true or false). @@ -1147,11 +1254,8 @@ bool Presolver::PresolveArrayIntElement(Constraint* ct, std::string* log) { // // TODO(user): The code is broken in case of integer-overflow. bool Presolver::PresolveLinear(Constraint* ct, std::string* log) { - if (ct->arguments[0].values.empty()) { - return false; - } - // Rule 2. - if (ct->arguments[1].IsArrayOfValues()) { + // Rule 2a and 2b. + if (ct->arguments[0].values.empty() || ct->arguments[1].IsArrayOfValues()) { log->append("rewrite constant linear equation"); int64 scalprod = 0; for (int i = 0; i < ct->arguments[0].values.size(); ++i) { @@ -1219,6 +1323,11 @@ bool Presolver::PresolveLinear(Constraint* ct, std::string* log) { return true; } + if (ct->arguments[0].values.empty()) { + return false; + } + // From now on, we assume the linear part is not empty. + // Rule 4. if (!ct->arguments[1].variables.empty()) { const int size = ct->arguments[1].variables.size(); @@ -1384,6 +1493,78 @@ bool Presolver::PropagatePositiveLinear(Constraint* ct, std::string* log) { return modified; } +// Input: int_lin_xx([c1, .., cn], [x1, .., xn], rhs) with ci >= 0. +// +// Computes the bounds on the rhs. +// Rule1: remove always true/false constraint or fix the reif Boolean. +// Rule2: transform ne/eq to gt/ge/lt/le if rhs is at one bound of its domain. +bool Presolver::SimplifyPositiveLinear(Constraint* ct, std::string* log) { + const int64 rhs = ct->arguments[2].Value(); + if (ct->presolve_propagation_done || rhs < 0 || + ct->arguments[1].variables.empty()) { + return false; + } + int64 rhs_min = 0; + int64 rhs_max = 0; + const int n = ct->arguments[0].values.size(); + for (int i = 0; i < n; ++i) { + const int64 coeff = ct->arguments[0].values[i]; + if (coeff < 0) return false; + rhs_min += coeff * ct->arguments[1].variables[i]->domain.Min(); + rhs_max += coeff * ct->arguments[1].variables[i]->domain.Max(); + } + if (rhs < rhs_min || rhs > rhs_max) { + if (ct->type == "int_lin_eq") { + ct->SetAsFalse(); + return true; + } else if (ct->type == "int_lin_eq_reif") { + ct->type = "bool_eq"; + ct->arguments[0] = ct->arguments[3]; + ct->arguments.resize(1); + ct->arguments.push_back(Argument::IntegerValue(0)); + return true; + } else if (ct->type == "int_lin_ne") { + ct->MarkAsInactive(); + return true; + } else if (ct->type == "int_lin_ne_reif") { + ct->type = "bool_eq"; + ct->arguments[0] = ct->arguments[3]; + ct->arguments.resize(1); + ct->arguments.push_back(Argument::IntegerValue(1)); + return true; + } + } else if (rhs == rhs_min) { + if (ct->type == "int_lin_eq") { + ct->type = "int_lin_le"; + return true; + } else if (ct->type == "int_lin_eq_reif") { + ct->type = "int_lin_le_reif"; + return true; + } else if (ct->type == "int_lin_ne") { + ct->type = "int_lin_gt"; + return true; + } else if (ct->type == "int_lin_ne_reif") { + ct->type = "int_lin_gt_reif"; + return true; + } + } else if (rhs == rhs_max) { + if (ct->type == "int_lin_eq") { + ct->type = "int_lin_ge"; + return true; + } else if (ct->type == "int_lin_eq_reif") { + ct->type = "int_lin_ge_reif"; + return true; + } else if (ct->type == "int_lin_ne") { + ct->type = "int_lin_lt"; + return true; + } else if (ct->type == "int_lin_ne_reif") { + ct->type = "int_lin_lt_reif"; + return true; + } + } + return false; +} + // Minizinc flattens 2d element constraints (x = A[y][z]) into 1d element // constraint with an affine mapping between y, z and the new index. // This rule stores the mapping to reconstruct the 2d element constraint. @@ -1512,9 +1693,7 @@ bool Presolver::PresolveSimplifyElement(Constraint* ct, std::string* log) { const int64 index = ct->arguments[0].Value() - 1; const int64 value = ct->arguments[1].values[index]; // Rewrite as equality. - ct->type = "int_eq"; - ct->arguments[0] = Argument::IntegerValue(value); - ct->RemoveArg(1); + SetConstraintAsIntEq(ct, ct->arguments[2].Var(), value); return true; } @@ -1790,32 +1969,39 @@ bool Presolver::PresolveSimplifyExprElement(Constraint* ct, std::string* log) { } // Rule 4. - if (ct->arguments[0].IsVariable() && ct->arguments[2].HasOneValue()) { - bool changed = false; - const int64 value = ct->arguments[2].Value(); - const std::vector& vars = ct->arguments[1].variables; - for (int index = 1; index <= vars.size(); ++index) { - if (!ct->arguments[0].Var()->domain.Contains(index)) continue; - const Domain& x = vars[index - 1]->domain; // 1-based in minizinc. - if (!x.Contains(value)) { - ct->arguments[0].Var()->domain.RemoveValue(index); - changed = true; + if (ct->arguments[0].IsVariable()) { + const Domain& domain = ct->arguments[0].Var()->domain; + std::vector to_keep; + const int array_size = ct->arguments[1].variables.size(); + bool remove_some = false; + if (domain.is_interval) { + for (int64 v = std::max(1, domain.values[0]); + v <= std::min(array_size, domain.values[1]); ++v) { + if (OverlapsAt(ct->arguments[1], v - 1, ct->arguments[2])) { + to_keep.push_back(v); + } else { + remove_some = true; + } + } + } else { + for (int64 v : domain.values) { + if (v < 1 || v > array_size) { + remove_some = true; + } else { + if (OverlapsAt(ct->arguments[1], v - 1, ct->arguments[2])) { + to_keep.push_back(v); + } else { + remove_some = true; + } + } } } - return changed; - } else if (ct->arguments[0].IsVariable() && ct->arguments[2].IsVariable()) { - bool changed = false; - const std::vector& vars = ct->arguments[1].variables; - for (int index = 1; index <= vars.size(); ++index) { - if (!ct->arguments[0].Var()->domain.Contains(index)) continue; - const Domain& x = vars[index - 1]->domain; // 1-based in minizinc. - const Domain& y = ct->arguments[2].Var()->domain; - if (x.Min() > y.Max() || y.Min() > x.Max()) { - ct->arguments[0].Var()->domain.RemoveValue(index); - changed = true; - } + if (remove_some) { + ct->arguments[0].Var()->domain.IntersectWithListOfIntegers(to_keep); + log->append(StringPrintf("reduce index domain to %s", + ct->arguments[0].Var()->DebugString().c_str())); + return true; } - return changed; } return false; @@ -1915,6 +2101,7 @@ bool Presolver::PropagateReifiedComparisons(Constraint* ct, std::string* log) { ct->arguments.push_back(Argument::IntVarRef(var)); ct->arguments.push_back(target); ct->type = parity ? "bool_eq" : "bool_not"; + return true; } else { // Rule 3. int state = 2; // 0 force_false, 1 force true, 2 unknown. @@ -1985,10 +2172,11 @@ bool Presolver::PropagateReifiedComparisons(Constraint* ct, std::string* log) { const Domain& rd = ct->arguments[1].Var()->domain; int state = 2; // 0 force_false, 1 force true, 2 unknown. if (id == "int_eq_reif" || id == "bool_eq_reif") { - if (ld.Min() > rd.Max() || ld.Max() < rd.Min()) { + if (!ld.OverlapsDomain(rd)) { state = 0; } } else if (id == "int_ne_reif" || id == "bool_ne_reif") { + // TODO(user): Test if the domain are disjoint. if (ld.Min() > rd.Max() || ld.Max() < rd.Min()) { state = 1; } @@ -2401,7 +2589,7 @@ bool Presolver::PresolveTableInt(Constraint* ct, std::string* log) { const int num_tuples = ct->arguments[1].values.size() / num_vars; int ignored_tuples = 0; std::vector new_tuples; - std::vector> visited_values(num_vars); + std::vector> next_values(num_vars); for (int t = 0; t < num_tuples; ++t) { std::vector tuple( ct->arguments[1].values.begin() + t * num_vars, @@ -2415,7 +2603,7 @@ bool Presolver::PresolveTableInt(Constraint* ct, std::string* log) { } if (valid) { for (int i = 0; i < num_vars; ++i) { - visited_values[i].insert(tuple[i]); + next_values[i].insert(tuple[i]); } new_tuples.insert(new_tuples.end(), tuple.begin(), tuple.end()); } else { @@ -2437,8 +2625,8 @@ bool Presolver::PresolveTableInt(Constraint* ct, std::string* log) { const int vmin = var->domain.values.empty() ? 0 : var->domain.values.front(); const int vmax = var->domain.values.empty() ? 0 : var->domain.values.back(); - std::vector values(visited_values[var_index].begin(), - visited_values[var_index].end()); + std::vector values(next_values[var_index].begin(), + next_values[var_index].end()); // TODO(user): Add return value that indicates change to IntersectXXX(). var->domain.IntersectWithListOfIntegers(values); variable_changed |= is_interval != var->domain.is_interval || @@ -2449,6 +2637,124 @@ bool Presolver::PresolveTableInt(Constraint* ct, std::string* log) { return variable_changed || ignored_tuples > 0; } +bool Presolver::PresolveRegular(Constraint* ct, std::string* log) { + const std::vector vars = ct->arguments[0].variables; + if (vars.empty()) { + // TODO(user): presolve when all variables are instantiated. + return false; + } + const int num_vars = vars.size(); + + const int64 num_states = ct->arguments[1].Value(); + const int64 num_values = ct->arguments[2].Value(); + + // Read transitions. + const std::vector& array_transitions = ct->arguments[3].values; + std::vector> automata; + int count = 0; + for (int i = 1; i <= num_states; ++i) { + for (int j = 1; j <= num_values; ++j) { + automata.push_back({i, j, array_transitions[count++]}); + } + } + + const int64 initial_state = ct->arguments[4].Value(); + + std::unordered_set final_states; + switch (ct->arguments[5].type) { + case fz::Argument::INT_VALUE: { + final_states.insert(ct->arguments[5].values[0]); + break; + } + case fz::Argument::INT_INTERVAL: { + for (int64 v = ct->arguments[5].values[0]; + v <= ct->arguments[5].values[1]; ++v) { + final_states.insert(v); + } + break; + } + case fz::Argument::INT_LIST: { + for (int64 value : ct->arguments[5].values) { + final_states.insert(value); + } + break; + } + default: { LOG(FATAL) << "Wrong constraint " << ct->DebugString(); } + } + + // Compute the set of reachable state at each time point. + std::vector> reachable_states(num_vars + 1); + reachable_states[0].insert(initial_state); + reachable_states[num_vars] = {final_states.begin(), final_states.end()}; + + // Forward. + for (int time = 0; time + 1 < num_vars; ++time) { + const Domain& domain = vars[time]->domain; + for (const std::vector& transition : automata) { + if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!domain.Contains(transition[1])) continue; + reachable_states[time + 1].insert(transition[2]); + } + } + + // Backward. + for (int time = num_vars - 1; time > 0; --time) { + std::unordered_set new_set; + const Domain& domain = vars[time]->domain; + for (const std::vector& transition : automata) { + if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!domain.Contains(transition[1])) continue; + if (!ContainsKey(reachable_states[time + 1], transition[2])) continue; + new_set.insert(transition[0]); + } + reachable_states[time].swap(new_set); + } + + // Prune the variables. + bool changed = false; + for (int time = 0; time < num_vars; ++time) { + Domain& domain = vars[time]->domain; + // Collect valid values. + std::unordered_set reached_values; + for (const std::vector& transition : automata) { + if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!domain.Contains(transition[1])) continue; + if (!ContainsKey(reachable_states[time + 1], transition[2])) continue; + reached_values.insert(transition[1]); + } + // Scan domain to check if we will remove values. + std::vector to_keep; + bool remove_some = false; + if (domain.is_interval) { + for (int64 v = domain.values[0]; v <= domain.values[1]; ++v) { + if (ContainsKey(reached_values, v)) { + to_keep.push_back(v); + } else { + remove_some = true; + } + } + } else { + for (int64 v : domain.values) { + if (ContainsKey(reached_values, v)) { + to_keep.push_back(v); + } else { + remove_some = true; + } + } + } + if (remove_some) { + const std::string& before = HASVLOG ? vars[time]->DebugString() : ""; + domain.IntersectWithListOfIntegers(to_keep); + if (HASVLOG) { + StringAppendF(log, "reduce domain of variable %d from %s to %s; ", time, + before.c_str(), vars[time]->DebugString().c_str()); + } + changed = true; + } + } + return changed; +} + // Tranforms diffn into all_different_int when sizes and y positions are all 1. // // Input : diffn([x1, .. xn], [1, .., 1], [1, .., 1], [1, .., 1]) @@ -2552,6 +2858,10 @@ bool Presolver::PresolveOneConstraint(Constraint* ct) { CALL_TYPE(ct, "int_lin_eq", PropagatePositiveLinear); CALL_TYPE(ct, "int_lin_le", PropagatePositiveLinear); CALL_TYPE(ct, "int_lin_ge", PropagatePositiveLinear); + CALL_TYPE(ct, "int_lin_eq", SimplifyPositiveLinear); + CALL_TYPE(ct, "int_lin_ne", SimplifyPositiveLinear); + CALL_TYPE(ct, "int_lin_eq_reif", SimplifyPositiveLinear); + CALL_TYPE(ct, "int_lin_ne_reif", SimplifyPositiveLinear); CALL_TYPE(ct, "int_lin_eq", CreateLinearTarget); CALL_TYPE(ct, "int_lin_eq", PresolveStoreMapping); CALL_TYPE(ct, "int_lin_eq_reif", CheckIntLinReifBounds); @@ -2602,6 +2912,7 @@ bool Presolver::PresolveOneConstraint(Constraint* ct) { } CALL_TYPE(ct, "table_int", PresolveTableInt); CALL_TYPE(ct, "diffn", PresolveDiffN); + CALL_TYPE(ct, "regular", PresolveRegular); // Last rule: if the target variable of a constraint is fixed, removed it // the target part. diff --git a/src/flatzinc/presolve.h b/src/flatzinc/presolve.h index 7809efe9da..9c102a08bb 100644 --- a/src/flatzinc/presolve.h +++ b/src/flatzinc/presolve.h @@ -142,6 +142,7 @@ class Presolver { bool PresolveLinear(Constraint* ct, std::string* log); bool RegroupLinear(Constraint* ct, std::string* log); bool PropagatePositiveLinear(Constraint* ct, std::string* log); + bool SimplifyPositiveLinear(Constraint* ct, std::string* log); bool PresolveStoreMapping(Constraint* ct, std::string* log); bool PresolveSimplifyElement(Constraint* ct, std::string* log); bool PresolveSimplifyExprElement(Constraint* ct, std::string* log); @@ -159,6 +160,7 @@ class Presolver { bool StoreIntEqReif(Constraint* ct, std::string* log); bool SimplifyIntNeReif(Constraint* ct, std::string* log); bool PresolveTableInt(Constraint* ct, std::string* log); + bool PresolveRegular(Constraint* ct, std::string* log); bool PresolveDiffN(Constraint* ct, std::string* log); // Helpers. diff --git a/src/flatzinc/sat_fz_solver.cc b/src/flatzinc/sat_fz_solver.cc index 2766e9ad66..89a963a1a7 100644 --- a/src/flatzinc/sat_fz_solver.cc +++ b/src/flatzinc/sat_fz_solver.cc @@ -462,17 +462,15 @@ void ExtractIntLinNeReif(const fz::Constraint& ct, SatModel* m) { m->model.Add(FixedWeightedSumReif(r.Negated(), vars, coeffs, rhs)); } -// r => (a == cte) +// r => (a == cte). void ImpliesEqualityToConstant(bool reverse_implication, IntegerVariable a, int64 cte, Literal r, SatModel* m) { if (m->model.Get(IsFixed(a))) { if (m->model.Get(Value(a)) == IntegerValue(cte)) { if (reverse_implication) { - FZLOG << "Case could have been presolved." << FZENDL; m->model.GetOrCreate()->AddUnitClause(r); } } else { - FZLOG << "Case could have been presolved." << FZENDL; m->model.GetOrCreate()->AddUnitClause(r.Negated()); } return; @@ -511,7 +509,6 @@ void ImpliesEqualityToConstant(bool reverse_implication, IntegerVariable a, } // Value is not found, the literal must be false. - FZLOG << "Case could have been presolved." << FZENDL; m->model.GetOrCreate()->AddUnitClause(r.Negated()); } @@ -936,7 +933,7 @@ bool ExtractConstraint(const fz::Constraint& ct, SatModel* m) { ExtractArrayVarIntElement(ct, m); } else if (ct.type == "all_different_int") { ExtractAllDifferentInt(ct, m); - } else if (ct.type == "int_eq") { + } else if (ct.type == "int_eq" || ct.type == "bool2int") { ExtractIntEq(ct, m); } else if (ct.type == "int_ne") { ExtractIntNe(ct, m); @@ -990,6 +987,8 @@ bool ExtractConstraint(const fz::Constraint& ct, SatModel* m) { ct.type == "variable_cumulative" || ct.type == "fixed_cumulative") { ExtractCumulative(ct, m); + } else if (ct.type == "false_constraint") { + m->model.GetOrCreate()->NotifyThatModelIsUnsat(); } else { return false; } @@ -1147,12 +1146,23 @@ void SolveWithSat(const fz::Model& fz_model, const fz::FlatzincParameters& p, FZLOG << "Extracting " << fz_model.constraints().size() << " constraints. " << FZENDL; std::set unsupported_types; + Trail* trail = m.model.GetOrCreate(); for (fz::Constraint* ct : fz_model.constraints()) { if (ct != nullptr && ct->active) { + const int old_num_fixed = trail->Index(); FZVLOG << "Extracting '" << ct->type << "'." << FZENDL; if (!ExtractConstraint(*ct, &m)) { unsupported_types.insert(ct->type); } + + // We propagate after each new Boolean constraint but not the integer + // ones. So we call Propagate() manually here. TODO(user): Do that + // automatically? + m.model.GetOrCreate()->Propagate(); + if (trail->Index() > old_num_fixed) { + FZVLOG << "Constraint fixed " << trail->Index() - old_num_fixed + << " Boolean variable(s): " << ct->DebugString() << FZENDL; + } if (m.model.GetOrCreate()->IsModelUnsat()) { FZLOG << "UNSAT during extraction (after adding '" << ct->type << "')." << FZENDL; @@ -1169,21 +1179,30 @@ void SolveWithSat(const fz::Model& fz_model, const fz::FlatzincParameters& p, } // Some stats. - int num_fully_encoded_variables = 0; - for (int i = 0; i < m.model.Get()->NumIntegerVariables(); ++i) { - if (m.model.Get()->VariableIsFullyEncoded( - IntegerVariable(i))) { - ++num_fully_encoded_variables; + { + int num_fully_encoded_variables = 0; + for (int i = 0; i < m.model.Get()->NumIntegerVariables(); + ++i) { + if (m.model.Get()->VariableIsFullyEncoded( + IntegerVariable(i))) { + ++num_fully_encoded_variables; + } } + // We divide by two because of the automatically created NegationOf() var. + FZLOG << "Num integer variables = " + << m.model.GetOrCreate()->NumIntegerVariables() / 2 + << FZENDL; + FZLOG << "Num fully encoded variable = " << num_fully_encoded_variables / 2 + << FZENDL; + FZLOG << "Num initial SAT variables = " + << m.model.Get()->NumVariables() << " (" + << m.model.Get()->LiteralTrail().Index() << " fixed)." + << FZENDL; + FZLOG << "Num constants = " << m.constant_map.size() << FZENDL; + FZLOG << "Num integer propagators = " + << m.model.GetOrCreate()->NumPropagators() + << FZENDL; } - // We divide by two because of the automatically created NegationOf() var. - FZLOG << "Num integer variables = " - << m.model.Get()->NumIntegerVariables() / 2 << FZENDL; - FZLOG << "Num fully encoded variable = " << num_fully_encoded_variables / 2 - << FZENDL; - FZLOG << "Num Boolean variables created = " - << m.model.Get()->NumVariables() << FZENDL; - FZLOG << "Num constants = " << m.constant_map.size() << FZENDL; int num_solutions = 0; int64 best_objective = 0; diff --git a/src/sat/integer.h b/src/sat/integer.h index b4d8230b9e..22ecb6d263 100644 --- a/src/sat/integer.h +++ b/src/sat/integer.h @@ -660,6 +660,9 @@ class GenericLiteralWatcher : public SatPropagator { // is usually done in a CP solver at the cost of a sligthly more complex API. void RegisterReversibleInt(int id, int* rev); + // Returns the number of registered propagators. + int NumPropagators() const { return in_queue_.size(); } + private: // Updates queue_ and in_queue_ with the propagator ids that need to be // called. diff --git a/src/sat/integer_expr.cc b/src/sat/integer_expr.cc index b92d13b498..1597b45d73 100644 --- a/src/sat/integer_expr.cc +++ b/src/sat/integer_expr.cc @@ -27,6 +27,9 @@ IntegerSumLE::IntegerSumLE(LiteralIndex reified_literal, coeffs_(coeffs), trail_(trail), integer_trail_(integer_trail) { + // TODO(user): deal with this corner case. + CHECK(!vars_.empty()); + // Handle negative coefficients. for (int i = 0; i < vars.size(); ++i) { if (coeffs_[i] < 0) { @@ -57,8 +60,6 @@ void IntegerSumLE::FillIntegerReason() { } bool IntegerSumLE::Propagate() { - CHECK(!vars_.empty()); // TODO(user): deal with this corner case. - // Reified case: If the reified literal is false, we ignore the constraint. if (reified_literal_ != kNoLiteralIndex && trail_->Assignment().LiteralIsFalse(Literal(reified_literal_))) { @@ -71,7 +72,8 @@ bool IntegerSumLE::Propagate() { } // Conflict? - if (new_lb > upper_bound_) { + const IntegerValue slack = upper_bound_ - new_lb; + if (slack < 0) { FillIntegerReason(); // Reified case: If the reified literal is unassigned, we set it to false, @@ -92,31 +94,34 @@ bool IntegerSumLE::Propagate() { return true; } - // Will only be filled on the first push. - integer_reason_.clear(); + // The integer_reason_ will only be filled on the first push. + bool first_push = true; // The lower bound of all the variables minus one can be used to update the // upper bound of the last one. for (int i = 0; i < vars_.size(); ++i) { - const IntegerValue new_lb_excluding_i = - new_lb - integer_trail_->LowerBound(vars_[i]) * coeffs_[i]; - const IntegerValue new_term_ub = upper_bound_ - new_lb_excluding_i; - const IntegerValue new_ub = new_term_ub / coeffs_[i]; - if (new_ub < integer_trail_->UpperBound(vars_[i])) { - if (integer_reason_.empty()) FillIntegerReason(); + const IntegerVariable var = vars_[i]; + const IntegerValue var_slack = + integer_trail_->UpperBound(var) - integer_trail_->LowerBound(var); + if (var_slack * coeffs_[i] > slack) { + if (first_push) { + first_push = false; + FillIntegerReason(); + } // We need to remove the entry index from the reason temporarily. IntegerLiteral saved; const int index = index_in_integer_reason_[i]; if (index >= 0) { saved = integer_reason_[index]; - std::swap(integer_reason_[index], integer_reason_.back()); + integer_reason_[index] = integer_reason_.back(); integer_reason_.pop_back(); } - if (!integer_trail_->Enqueue( - IntegerLiteral::LowerOrEqual(vars_[i], new_ub), literal_reason_, - integer_reason_)) { + const IntegerValue new_ub = + integer_trail_->LowerBound(var) + slack / coeffs_[i]; + if (!integer_trail_->Enqueue(IntegerLiteral::LowerOrEqual(var, new_ub), + literal_reason_, integer_reason_)) { return false; } diff --git a/src/sat/integer_expr.h b/src/sat/integer_expr.h index 690dd3fe38..90bc0606fe 100644 --- a/src/sat/integer_expr.h +++ b/src/sat/integer_expr.h @@ -296,6 +296,7 @@ inline std::function WeightedSumGreaterOrEqualReif( } // Weighted sum == constant reified. +// TODO(user): Simplify if the constant is at the edge of the possible values. template inline std::function FixedWeightedSumReif( Literal is_eq, const std::vector& vars, @@ -312,6 +313,7 @@ inline std::function FixedWeightedSumReif( } // Weighted sum != constant. +// TODO(user): Simplify if the constant is at the edge of the possible values. template inline std::function WeightedSumNotEqual( const std::vector& vars, const VectorInt& coefficients,