add new API on sat solver

This commit is contained in:
Laurent Perron
2016-10-12 12:07:20 -07:00
parent 05cbf0de00
commit e4a2d9f2df
9 changed files with 317 additions and 101 deletions

View File

@@ -23,6 +23,7 @@
#include "base/hash.h"
#include "algorithms/find_graph_symmetries.h"
#include "graph/graph.h"
#include "graph/io.h"
#include "graph/util.h"
DEFINE_string(debug_dump_symmetry_graph_to_file, "",

View File

@@ -18,6 +18,60 @@
namespace operations_research {
namespace sat {
// Propagates the constraint literals_[0] XOR ... XOR literals_[n-1] == value_.
//
// One O(n) scan over the literals:
// - If two or more literals are unassigned, nothing can be deduced.
// - If exactly one is unassigned, its value is forced so that the XOR of the
//   assigned literals matches value_.
// - If all are assigned and the XOR differs from value_, a conflict is
//   reported on the trail.
bool BooleanXorPropagator::Propagate(Trail* trail) {
  bool sum = false;           // XOR of the literals assigned so far.
  int unassigned_index = -1;  // Index of the unique unassigned literal, if any.
  for (int i = 0; i < literals_.size(); ++i) {
    const Literal l = literals_[i];
    if (trail->Assignment().LiteralIsFalse(l)) {
      sum ^= false;
    } else if (trail->Assignment().LiteralIsTrue(l)) {
      sum ^= true;
    } else {
      // If we have more than one unassigned literal, we can't deduce anything.
      if (unassigned_index != -1) return true;
      unassigned_index = i;
    }
  }

  // Propagates? Exactly one literal is unassigned: force its value so that
  // the full XOR equals value_.
  if (unassigned_index != -1) {
    // EnqueueLiteral() is expected to set these pointers to the reason
    // vectors to fill (new reason-filling API) -- confirmed by the pushes
    // below happening after the call.
    std::vector<Literal>* literal_reason;
    std::vector<IntegerLiteral>* integer_reason;
    const Literal u = literals_[unassigned_index];
    integer_trail_->EnqueueLiteral(sum == value_ ? u.Negated() : u,
                                   &literal_reason, &integer_reason);
    // The reason is the current assignment of all the other literals.
    for (int i = 0; i < literals_.size(); ++i) {
      if (i == unassigned_index) continue;
      const Literal l = literals_[i];
      literal_reason->push_back(
          trail->Assignment().LiteralIsFalse(l) ? l : l.Negated());
    }
    return true;
  }

  // Ok. Everything is assigned and the XOR has the required value.
  if (sum == value_) return true;

  // Conflict. The conflict clause is the negation of the full assignment.
  std::vector<Literal>* conflict = trail->MutableConflict();
  conflict->clear();
  for (int i = 0; i < literals_.size(); ++i) {
    const Literal l = literals_[i];
    conflict->push_back(trail->Assignment().LiteralIsFalse(l) ? l
                                                              : l.Negated());
  }
  return false;
}
// Registers this propagator with the watcher: it will be woken up whenever
// one of its literals is assigned (either polarity is watched).
void BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
  const int propagator_id = watcher->Register(this);
  for (int i = 0; i < literals_.size(); ++i) {
    const Literal literal = literals_[i];
    watcher->WatchLiteral(literal, propagator_id);
    watcher->WatchLiteral(literal.Negated(), propagator_id);
  }
}
std::function<void(Model*)> AllDifferent(const std::vector<IntegerVariable>& vars) {
return [=](Model* model) {
hash_set<IntegerValue> fixed_values;

View File

@@ -20,9 +20,47 @@
namespace operations_research {
namespace sat {
// Propagate the fact that a XOR of literals is equal to the given value.
// The complexity is in O(n).
//
// TODO(user): By using a two watcher mechanism, we can propagate this a lot
// faster.
class BooleanXorPropagator : public PropagatorInterface {
 public:
  // Enforces literals[0] XOR ... XOR literals[n-1] == value.
  // integer_trail is used to enqueue propagated literals with their reason.
  BooleanXorPropagator(const std::vector<Literal>& literals, bool value,
                       IntegerTrail* integer_trail)
      : literals_(literals), value_(value), integer_trail_(integer_trail) {}

  // PropagatorInterface API.
  bool Propagate(Trail* trail) final;

  // Watches both polarities of all the literals so that Propagate() is
  // called on any assignment involving them.
  void RegisterWith(GenericLiteralWatcher* watcher);

 private:
  const std::vector<Literal> literals_;  // The literals of the XOR.
  const bool value_;                     // Target value of the XOR.
  IntegerTrail* integer_trail_;          // External dependency; not owned.

  DISALLOW_COPY_AND_ASSIGN(BooleanXorPropagator);
};
// ============================================================================
// Model based functions.
// ============================================================================
// Enforces that the given tuple of variables takes different values.
std::function<void(Model*)> AllDifferent(const std::vector<IntegerVariable>& vars);
// Enforces the XOR of a set of literals to be equal to the given value.
inline std::function<void(Model*)> LiteralXorIs(const std::vector<Literal>& literals,
bool value) {
return [=](Model* model) {
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
BooleanXorPropagator* constraint =
new BooleanXorPropagator(literals, value, integer_trail);
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
};
}
} // namespace sat
} // namespace operations_research

View File

@@ -32,16 +32,22 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var,
// the caller to deal with this case.
CHECK_NE(values.size(), 1);
// If the variable has already been fully encoded, for now we check that
// the sets of value is the same.
//
// TODO(user): Take the intersection, and handle that case in the constraints
// creation functions.
// If the variable has already been fully encoded, we set to false the
// literals that cannot be true anymore. We also log a warning because ideally
// these intersections should happen in the presolve.
if (ContainsKey(full_encoding_index_, i_var)) {
int num_fixed = 0;
hash_set<IntegerValue> to_interset(values.begin(), values.end());
const std::vector<ValueLiteralPair>& encoding = FullDomainEncoding(i_var);
CHECK_EQ(values.size(), encoding.size());
for (int i = 0; i < values.size(); ++i) {
CHECK_EQ(values[i], encoding[i].value);
for (const ValueLiteralPair& p : encoding) {
if (!ContainsKey(to_interset, p.value)) {
// TODO(user): also remove this entry from encoding.
++num_fixed;
sat_solver_->AddUnitClause(p.literal.Negated());
}
}
if (num_fixed > 0) {
LOG(WARNING) << "Domain intersection removed " << num_fixed << " values.";
}
return;
}

View File

@@ -186,9 +186,11 @@ class IntegerEncoder {
// 3/ The encoding for NegationOf(var) is automatically created too. It reuses
// the same Boolean variable as the encoding of var.
//
// Calling this more than once is an error (Checked).
// TODO(user): we could instead only keep the intersection and fix the now
// impossible values to zero.
// Calling this more than once will take the intersection of all the given
// values arguments. However, this is not optimal because the first calls may
// create new Boolean variables that will later be fixed, so we log a warning
// when this happens. Ideally, the intersection should be done in a presolve
// step to be as efficient as possible here.
//
// Note(user): There is currently no relation here between
// FullyEncodeVariable() and CreateAssociatedLiteral(). However the
@@ -730,6 +732,14 @@ inline std::function<void(Model*)> LowerOrEqual(IntegerVariable v, int64 ub) {
};
}
// Fix v to a given value, i.e. v == value.
// Encoded as the conjunction of the two inequalities v <= value and
// v >= value.
inline std::function<void(Model*)> Equality(IntegerVariable v, int64 value) {
  return [=](Model* model) {
    model->Add(LowerOrEqual(v, value));
    model->Add(GreaterOrEqual(v, value));
  };
}
// Associate the given literal to the given integer inequality.
inline std::function<void(Model*)> Equality(IntegerLiteral i, Literal l) {
return [=](Model* model) {

View File

@@ -198,6 +198,23 @@ template <typename VectorInt>
inline std::function<void(Model*)> WeightedSumLowerOrEqual(
const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
int64 upper_bound) {
// Special cases.
// TODO(user): Do the same for the reified case.
CHECK_GE(vars.size(), 1) << "Should be encoded differently.";
if (vars.size() == 2 && (coefficients[0] == 1 || coefficients[0] == -1) &&
(coefficients[1] == 1 || coefficients[1] == -1)) {
return Sum2LowerOrEqual(
coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]), upper_bound);
}
if (vars.size() == 3 && (coefficients[0] == 1 || coefficients[0] == -1) &&
(coefficients[1] == 1 || coefficients[1] == -1) &&
(coefficients[2] == 1 || coefficients[2] == -1)) {
return Sum3LowerOrEqual(
coefficients[0] == 1 ? vars[0] : NegationOf(vars[0]),
coefficients[1] == 1 ? vars[1] : NegationOf(vars[1]),
coefficients[2] == 1 ? vars[2] : NegationOf(vars[2]), upper_bound);
}
return [=](Model* model) {
IntegerSumLE* constraint = new IntegerSumLE(
kNoLiteralIndex, vars,
@@ -213,17 +230,10 @@ template <typename VectorInt>
inline std::function<void(Model*)> WeightedSumGreaterOrEqual(
const std::vector<IntegerVariable>& vars, const VectorInt& coefficients,
int64 lower_bound) {
return [=](Model* model) {
// We just negate everything and use an IntegerSumLE() constraints.
std::vector<IntegerValue> negated_coeffs(coefficients.begin(),
coefficients.end());
for (IntegerValue& ref : negated_coeffs) ref = -ref;
IntegerSumLE* constraint = new IntegerSumLE(
kNoLiteralIndex, vars, negated_coeffs, IntegerValue(-lower_bound),
model->GetOrCreate<IntegerTrail>());
constraint->RegisterWith(model->GetOrCreate<GenericLiteralWatcher>());
model->TakeOwnership(constraint);
};
// We just negate everything and use an IntegerSumLE() constraints.
std::vector<IntegerValue> negated_coeffs(coefficients.begin(), coefficients.end());
for (IntegerValue& ref : negated_coeffs) ref = -ref;
return WeightedSumLowerOrEqual(vars, negated_coeffs, -lower_bound);
}
// Weighted sum == constant.

View File

@@ -70,9 +70,6 @@ bool PrecedencesPropagator::Propagate(Trail* trail) {
for (const int arc_index : potential_arcs_[literal.Index()]) {
const ArcInfo& arc = arcs_[arc_index];
impacted_arcs_[arc.tail_var].push_back(arc_index);
if (arc.offset_var != kNoIntegerVariable) {
impacted_arcs_[arc.offset_var].push_back(arc_index);
}
}
// Iterate again to check for a propagation and indirectly update
@@ -85,7 +82,6 @@ bool PrecedencesPropagator::Propagate(Trail* trail) {
if (new_head_lb > integer_trail_->LowerBound(arc.head_var)) {
if (!EnqueueAndCheck(arc, new_head_lb, trail)) return false;
}
PropagateMaxOffsetIfNeeded(arc, trail);
}
}
@@ -115,9 +111,6 @@ void PrecedencesPropagator::Untrail(const Trail& trail, int trail_index) {
for (const int arc_index : potential_arcs_[literal.Index()]) {
const ArcInfo& arc = arcs_[arc_index];
impacted_arcs_[arc.tail_var].pop_back();
if (arc.offset_var != kNoIntegerVariable) {
impacted_arcs_[arc.offset_var].pop_back();
}
}
}
}
@@ -234,35 +227,59 @@ void PrecedencesPropagator::AddArc(IntegerVariable tail, IntegerVariable head,
AdjustSizeFor(head);
if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var);
for (const bool forward : {true, false}) {
const IntegerVariable tail_var = forward ? tail : NegationOf(head);
const IntegerVariable head_var = forward ? head : NegationOf(tail);
// Handle level zero stuff.
DCHECK_EQ(trail_->CurrentDecisionLevel(), 0);
if (l != kNoLiteralIndex) {
// Check if the conditional literal is already fixed.
if (trail_->Assignment().LiteralIsTrue(Literal(l))) {
l = kNoLiteralIndex; // At true, same as fixed arc.
} else if (trail_->Assignment().LiteralIsFalse(Literal(l))) {
return; // At false, nothing to add.
}
}
if (l != kNoLiteralIndex && l.value() >= potential_arcs_.size()) {
potential_arcs_.resize(l.value() + 1);
}
struct InternalArc {
IntegerVariable tail_var;
IntegerVariable head_var;
IntegerVariable offset_var;
};
std::vector<InternalArc> to_add;
if (offset_var == kNoIntegerVariable) {
// a + offset <= b and -b + offset <= -a
to_add.push_back({tail, head, kNoIntegerVariable});
to_add.push_back({NegationOf(head), NegationOf(tail), kNoIntegerVariable});
} else {
// tail (a) and offset_var (b) are symmetric, so we add:
// - a + b + offset <= c
to_add.push_back({tail, head, offset_var});
to_add.push_back({offset_var, head, tail});
// - a - c + offset <= -b
to_add.push_back({tail, NegationOf(offset_var), NegationOf(head)});
to_add.push_back({NegationOf(head), NegationOf(offset_var), tail});
// - b - c + offset <= -a
to_add.push_back({offset_var, NegationOf(tail), NegationOf(head)});
to_add.push_back({NegationOf(head), NegationOf(tail), offset_var});
}
for (const InternalArc a : to_add) {
// Since we add a new arc, we will need to consider its tail during the next
// propagation.
//
// TODO(user): Adding arcs and then calling Untrail() before Propagate()
// will cause this mechanism to break. Find a more robust implementation.
modified_vars_.Set(tail_var);
modified_vars_.Set(a.tail_var);
const int arc_index = arcs_.size();
if (l == kNoLiteralIndex) {
impacted_arcs_[tail_var].push_back(arc_index);
if (offset_var != kNoIntegerVariable) {
impacted_arcs_[offset_var].push_back(arc_index);
}
impacted_arcs_[a.tail_var].push_back(arc_index);
} else {
impacted_potential_arcs_[tail_var].push_back(arc_index);
if (offset_var != kNoIntegerVariable) {
impacted_potential_arcs_[offset_var].push_back(arc_index);
}
if (l.value() >= potential_arcs_.size()) {
potential_arcs_.resize(l.value() + 1);
}
impacted_potential_arcs_[a.tail_var].push_back(arc_index);
potential_arcs_[l].push_back(arc_index);
}
arcs_.push_back({tail_var, head_var, offset, offset_var, l});
arcs_.push_back({a.tail_var, a.head_var, offset, a.offset_var, l});
}
}
@@ -364,30 +381,6 @@ bool PrecedencesPropagator::EnqueueAndCheck(const ArcInfo& arc,
literal_reason_, integer_reason_);
}
// Propagates an upper bound on arc.offset_var from the current bounds of the
// arc endpoints:
//   offset_var <= UpperBound(head_var) - LowerBound(tail_var).
// (The arcs encode precedences of the form tail + offset(s) <= head.)
// Returns the result of IntegerTrail::Enqueue() when a bound is pushed, and
// false when there is nothing to propagate.
bool PrecedencesPropagator::PropagateMaxOffsetIfNeeded(const ArcInfo& arc,
                                                       Trail* trail) {
  // Nothing to do for an arc with no variable offset.
  if (arc.offset_var == kNoIntegerVariable) return false;
  // Only propagate when the (optional) head is known to be present.
  if (!IsInvalidOrTrue(OptionalLiteralOf(arc.head_var), *trail)) return false;
  // CapSub() guards against overflow of the difference.
  const IntegerValue max_duration =
      CapSub(integer_trail_->UpperBound(arc.head_var),
             integer_trail_->LowerBound(arc.tail_var));
  if (max_duration < integer_trail_->UpperBound(arc.offset_var)) {
    // The reason is the presence literals (when valid) plus the two bounds
    // used to compute max_duration above.
    literal_reason_.clear();
    AppendNegationIfValid(arc.presence_l, &literal_reason_);
    AppendNegationIfValid(OptionalLiteralOf(arc.tail_var), &literal_reason_);
    integer_reason_.clear();
    integer_reason_.push_back(
        integer_trail_->LowerBoundAsLiteral(arc.tail_var));
    integer_reason_.push_back(
        integer_trail_->UpperBoundAsLiteral(arc.head_var));
    return integer_trail_->Enqueue(
        IntegerLiteral::LowerOrEqual(arc.offset_var, max_duration),
        literal_reason_, integer_reason_);
  }
  return false;
}
bool PrecedencesPropagator::NoPropagationLeft(const Trail& trail) const {
const int num_nodes = impacted_arcs_.size();
for (IntegerVariable var(0); var < num_nodes; ++var) {
@@ -586,17 +579,6 @@ bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) {
const ArcInfo& arc = arcs_[arc_index];
if (!ArcShouldPropagate(arc, *trail)) continue;
if (PropagateMaxOffsetIfNeeded(arc, trail)) {
const IntegerVariable minus_offset = NegationOf(arc.offset_var);
// TODO(user): We currently don't have any cycle detection here.
bf_can_be_skipped_[minus_offset.value()] = false;
if (!bf_in_queue_[minus_offset.value()]) {
bf_queue_.push_back(minus_offset.value());
bf_in_queue_[minus_offset.value()] = true;
}
}
const IntegerValue candidate =
CapAdd(integer_trail_->LowerBound(arc.tail_var), ArcOffset(arc));
if (candidate > integer_trail_->LowerBound(arc.head_var)) {
@@ -614,6 +596,13 @@ bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) {
return false;
}
// We need to enforce the invariant that only the arc_index in
// bf_parent_arc_of_[] are marked (but not necessarily all of them
// since we unmark some in DisassembleSubtree()).
if (bf_parent_arc_of_[arc.head_var.value()] != -1) {
arcs_[bf_parent_arc_of_[arc.head_var.value()]].is_marked = false;
}
// Tricky: We just enqueued the fact that the lower-bound of head is
// candidate. However, because the domain of head may be discrete, it is
// possible that the lower-bound of head is now higher than candidate!
@@ -621,14 +610,12 @@ bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) {
// don't wrongly detect a positive weight cycle because of this "extra
// push".
if (integer_trail_->LowerBound(arc.head_var) == candidate) {
// We need to enforce the invariant that only the arc_index in
// bf_parent_arc_of_[] are marked (but not necessarily all of them
// since we unmark some in DisassembleSubtree()).
if (bf_parent_arc_of_[arc.head_var.value()] != -1) {
arcs_[bf_parent_arc_of_[arc.head_var.value()]].is_marked = false;
}
bf_parent_arc_of_[arc.head_var.value()] = arc_index;
arcs_[arc_index].is_marked = true;
} else {
// We still unmark any previous dependency, since we have pushed the
// value of arc.head_var further.
bf_parent_arc_of_[arc.head_var.value()] = -1;
}
bf_can_be_skipped_[arc.head_var.value()] = false;

View File

@@ -39,9 +39,10 @@ namespace sat {
// Another word is "separation logic".
class PrecedencesPropagator : public Propagator {
public:
PrecedencesPropagator(IntegerTrail* integer_trail,
PrecedencesPropagator(Trail* trail, IntegerTrail* integer_trail,
GenericLiteralWatcher* watcher)
: Propagator("PrecedencesPropagator"),
trail_(trail),
integer_trail_(integer_trail),
watcher_(watcher),
watcher_id_(watcher->Register(this)) {
@@ -49,9 +50,9 @@ class PrecedencesPropagator : public Propagator {
}
static PrecedencesPropagator* CreateInModel(Model* model) {
PrecedencesPropagator* precedences =
new PrecedencesPropagator(model->GetOrCreate<IntegerTrail>(),
model->GetOrCreate<GenericLiteralWatcher>());
PrecedencesPropagator* precedences = new PrecedencesPropagator(
model->GetOrCreate<Trail>(), model->GetOrCreate<IntegerTrail>(),
model->GetOrCreate<GenericLiteralWatcher>());
// TODO(user): Find a way to have more control on the order in which
// the propagators are added.
@@ -86,6 +87,10 @@ class PrecedencesPropagator : public Propagator {
// when I wrote this, I just had a couple of problems to test this on.
void AddPrecedenceWithVariableOffset(IntegerVariable i1, IntegerVariable i2,
IntegerVariable offset_var);
void AddPrecedenceWithVariableAndFixedOffset(IntegerVariable i1,
IntegerVariable i2,
IntegerValue offset,
IntegerVariable offset_var);
// An optional integer variable has a special behavior:
// - If the bounds on i cross each other, then is_present must be false.
@@ -165,7 +170,6 @@ class PrecedencesPropagator : public Propagator {
// from the current value of arc.tail_lb and the offset of this arc.
bool EnqueueAndCheck(const ArcInfo& arc, IntegerValue new_head_lb,
Trail* trail);
bool PropagateMaxOffsetIfNeeded(const ArcInfo& arc, Trail* trail);
IntegerValue ArcOffset(const ArcInfo& arc) const;
// Returns true iff this arc should propagate. For now, this is true when:
@@ -207,6 +211,7 @@ class PrecedencesPropagator : public Propagator {
// External class needed to get the IntegerVariable lower bounds and Enqueue
// new ones.
Trail* trail_;
IntegerTrail* integer_trail_;
GenericLiteralWatcher* watcher_;
int watcher_id_;
@@ -307,6 +312,12 @@ inline void PrecedencesPropagator::AddPrecedenceWithVariableOffset(
AddArc(i1, i2, /*offset=*/IntegerValue(0), offset_var, /*l=*/kNoLiteralIndex);
}
// Same as AddPrecedenceWithVariableOffset() above, but with an additional
// fixed offset term passed down to AddArc().
inline void PrecedencesPropagator::AddPrecedenceWithVariableAndFixedOffset(
    IntegerVariable i1, IntegerVariable i2, IntegerValue offset,
    IntegerVariable offset_var) {
  AddArc(i1, i2, offset, offset_var, /*l=*/kNoLiteralIndex);
}
// =============================================================================
// Model based functions.
// =============================================================================
@@ -329,6 +340,25 @@ inline std::function<void(Model*)> LowerOrEqualWithOffset(IntegerVariable a,
};
}
// a + b <= ub.
// Rewritten as a precedence between a and -b (a + (-ub) <= -b) and delegated
// to LowerOrEqualWithOffset().
inline std::function<void(Model*)> Sum2LowerOrEqual(IntegerVariable a,
                                                    IntegerVariable b,
                                                    int64 ub) {
  return LowerOrEqualWithOffset(a, NegationOf(b), -ub);
}
// a + b + c <= ub.
// Added directly on the PrecedencesPropagator as the arc
// a + b + (-ub) <= -c, with b playing the role of the variable offset.
inline std::function<void(Model*)> Sum3LowerOrEqual(IntegerVariable a,
                                                    IntegerVariable b,
                                                    IntegerVariable c,
                                                    int64 ub) {
  return [=](Model* model) {
    PrecedencesPropagator* const precedences =
        model->GetOrCreate<PrecedencesPropagator>();
    precedences->AddPrecedenceWithVariableAndFixedOffset(a, NegationOf(c),
                                                         IntegerValue(-ub), b);
  };
}
// a >= b.
inline std::function<void(Model*)> GreaterOrEqual(IntegerVariable a,
IntegerVariable b) {
@@ -387,6 +417,19 @@ inline std::function<void(Model*)> ReifiedEquality(IntegerVariable a,
};
}
// a != b.
inline std::function<void(Model*)> NotEqual(IntegerVariable a,
                                            IntegerVariable b) {
  return [=](Model* model) {
    // Two reified inequalities, a_le_b <=> (a <= b) and b_le_a <=> (b <= a),
    // that are forbidden to hold at the same time. Since a == b would force
    // both to be true, this excludes exactly the equality case.
    const Literal a_le_b = Literal(model->Add(NewBooleanVariable()), true);
    const Literal b_le_a = Literal(model->Add(NewBooleanVariable()), true);
    model->Add(Implication(a_le_b, b_le_a.Negated()));
    model->Add(ReifiedLowerOrEqualWithOffset(a, b, 0, a_le_b));
    model->Add(ReifiedLowerOrEqualWithOffset(b, a, 0, b_le_a));
  };
}
} // namespace sat
} // namespace operations_research

View File

@@ -13,6 +13,8 @@
#include "sat/table.h"
#include <unordered_set>
#include "base/map_util.h"
#include "base/stl_util.h"
@@ -39,13 +41,39 @@ std::vector<std::vector<IntegerValue>> Transpose(const std::vector<std::vector<i
// Converts the vector representation returned by FullDomainEncoding() to a map.
//
// Note(review): the span contained two interleaved versions of the loop
// header (one calling FullDomainEncoding() through GetOrCreate<> inline, one
// through a local encoder pointer), leaving unbalanced braces. Only the
// single-loop version is kept.
hash_map<IntegerValue, Literal> GetEncoding(IntegerVariable var, Model* model) {
  hash_map<IntegerValue, Literal> encoding;
  IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
  for (const auto& entry : encoder->FullDomainEncoding(var)) {
    encoding[entry.value] = entry.literal;
  }
  return encoding;
}
// Removes from "values" everything that is incompatible with the current
// domain of var: values outside [LowerBound(var), UpperBound(var)] and, when
// var is already fully encoded, values absent from its encoding.
void FilterValues(IntegerVariable var, Model* model,
                  std::unordered_set<int64>* values) {
  const int64 lb = model->Get(LowerBound(var));
  const int64 ub = model->Get(UpperBound(var));
  IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
  const bool fully_encoded = encoder->VariableIsFullyEncoded(var);
  // Only query the encoding when it exists.
  hash_map<IntegerValue, Literal> encoding;
  if (fully_encoded) encoding = GetEncoding(var, model);
  // Single pass; erase() returns the iterator to the next element (C++11),
  // which avoids the manual copy-then-increment idiom and the duplicated
  // loop of the previous version.
  for (auto it = values->begin(); it != values->end();) {
    const int64 v = *it;
    const bool remove =
        v < lb || v > ub ||
        (fully_encoded && !ContainsKey(encoding, IntegerValue(v)));
    if (remove) {
      it = values->erase(it);
    } else {
      ++it;
    }
  }
}
// Add the implications and clauses to link one column of a table to the Literal
// controlling if the lines are possible or not. The column has the given values,
// and the Literal of the column variable can be retrieved using the encoding
@@ -80,6 +108,36 @@ void ProcessOneColumn(const std::vector<Literal>& line_literals,
std::function<void(Model*)> TableConstraint(
const std::vector<IntegerVariable>& vars, const std::vector<std::vector<int64>>& tuples) {
return [=](Model* model) {
const int n = vars.size();
// Compute the set of possible values for each variable (from the table).
std::vector<std::unordered_set<int64>> values_per_var(n);
for (const std::vector<int64>& tuple : tuples) {
for (int i = 0; i < n; ++i) {
values_per_var[i].insert(tuple[i]);
}
}
// Filter each values_per_var entries using the current variable domain.
for (int i = 0; i < n; ++i) {
FilterValues(vars[i], model, &values_per_var[i]);
}
// Remove the unreachable tuples.
std::vector<std::vector<int64>> new_tuples;
for (const std::vector<int64>& tuple : tuples) {
bool keep = true;
for (int i = 0; i < n; ++i) {
if (!ContainsKey(values_per_var[i], tuple[i])) {
keep = false;
break;
}
}
if (keep) {
new_tuples.push_back(tuple);
}
}
// Create one Boolean variable per tuple to indicate if it can still be
// selected or not. Note that we don't enforce exactly one tuple to be
// selected because these variables are just used by this constraint, so
@@ -89,18 +147,24 @@ std::function<void(Model*)> TableConstraint(
// new BooleanVariable corresponding to this line since we can use the one
// corresponding to this value in that column.
std::vector<Literal> tuple_literals;
for (int i = 0; i < tuples.size(); ++i) {
for (int i = 0; i < new_tuples.size(); ++i) {
tuple_literals.push_back(Literal(model->Add(NewBooleanVariable()), true));
}
// Fully encode the variables using all the values appearing in the tuples.
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
hash_map<IntegerValue, Literal> encoding;
const std::vector<std::vector<IntegerValue>>& tr_tuples = Transpose(tuples);
for (int i = 0; i < vars.size(); ++i) {
encoder->FullyEncodeVariable(vars[i], tr_tuples[i]);
encoding = GetEncoding(vars[i], model);
ProcessOneColumn(tuple_literals, tr_tuples[i], encoding, model);
const std::vector<std::vector<IntegerValue>> tr_tuples = Transpose(new_tuples);
for (int i = 0; i < n; ++i) {
const IntegerValue first = tr_tuples[i].front();
if (std::all_of(tr_tuples[i].begin(), tr_tuples[i].end(),
[first](IntegerValue v) { return v == first; })) {
model->Add(Equality(vars[i], first.value()));
} else {
encoder->FullyEncodeVariable(vars[i], tr_tuples[i]);
encoding = GetEncoding(vars[i], model);
ProcessOneColumn(tuple_literals, tr_tuples[i], encoding, model);
}
}
};
}
@@ -132,6 +196,9 @@ std::function<void(Model*)> TransitionConstraint(
reachable_states[n] = {final_states.begin(), final_states.end()};
// Forward.
//
// TODO(user): filter using the domain of vars[time] that may not contain
// all the possible transitions.
for (int time = 0; time + 1 < n; ++time) {
for (const std::vector<int64>& transition : automata) {
if (!ContainsKey(reachable_states[time], transition[0])) continue;