minor improvements to sat internals
This commit is contained in:
@@ -47,7 +47,8 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var,
|
||||
}
|
||||
}
|
||||
if (num_fixed > 0) {
|
||||
LOG(WARNING) << "Domain intersection removed " << num_fixed << " values.";
|
||||
LOG(WARNING) << "Domain intersection removed " << num_fixed << " values "
|
||||
<< "(out of " << encoding.size() << ").";
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -135,13 +136,21 @@ void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) {
|
||||
auto after_it = map_ref.lower_bound(i_lit.bound);
|
||||
if (after_it != map_ref.end()) {
|
||||
// Literal(after) => literal
|
||||
sat_solver_->AddBinaryClauseDuringSearch(after_it->second.Negated(),
|
||||
literal);
|
||||
if (sat_solver_->CurrentDecisionLevel() == 0) {
|
||||
sat_solver_->AddBinaryClause(after_it->second.Negated(), literal);
|
||||
} else {
|
||||
sat_solver_->AddBinaryClauseDuringSearch(after_it->second.Negated(),
|
||||
literal);
|
||||
}
|
||||
}
|
||||
if (after_it != map_ref.begin()) {
|
||||
// literal => Literal(before)
|
||||
sat_solver_->AddBinaryClauseDuringSearch(literal.Negated(),
|
||||
(--after_it)->second);
|
||||
if (sat_solver_->CurrentDecisionLevel() == 0) {
|
||||
sat_solver_->AddBinaryClause(literal.Negated(), (--after_it)->second);
|
||||
} else {
|
||||
sat_solver_->AddBinaryClauseDuringSearch(literal.Negated(),
|
||||
(--after_it)->second);
|
||||
}
|
||||
}
|
||||
|
||||
// Add the new entry.
|
||||
|
||||
@@ -862,10 +862,13 @@ inline std::function<void(Model*)> ReifiedInInterval(IntegerVariable v,
|
||||
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
|
||||
const auto lb_lit = IntegerLiteral::GreaterOrEqual(v, IntegerValue(lb));
|
||||
const auto ub_lit = IntegerLiteral::LowerOrEqual(v, IntegerValue(ub));
|
||||
if (lb < model->Get(LowerBound(v))) {
|
||||
CHECK_LT(ub, model->Get(UpperBound(v))) << "Should be presolved.";
|
||||
model->Add(Equality(ub_lit, in_interval));
|
||||
} else if (ub > model->Get(UpperBound(v))) {
|
||||
if (lb <= model->Get(LowerBound(v))) {
|
||||
if (ub >= model->Get(UpperBound(v))) {
|
||||
model->GetOrCreate<SatSolver>()->AddUnitClause(in_interval);
|
||||
} else {
|
||||
model->Add(Equality(ub_lit, in_interval));
|
||||
}
|
||||
} else if (ub >= model->Get(UpperBound(v))) {
|
||||
model->Add(Equality(lb_lit, in_interval));
|
||||
} else {
|
||||
const Literal is_ge_lb = encoder->GetOrCreateAssociatedLiteral(lb_lit);
|
||||
|
||||
@@ -56,6 +56,8 @@ void FilterValues(IntegerVariable var, Model* model,
|
||||
const int64 ub = model->Get(UpperBound(var));
|
||||
|
||||
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
|
||||
const VariablesAssignment& assignment =
|
||||
model->GetOrCreate<Trail>()->Assignment();
|
||||
if (encoder->VariableIsFullyEncoded(var)) {
|
||||
const auto encoding = GetEncoding(var, model);
|
||||
for (auto it = values->begin(); it != values->end();) {
|
||||
@@ -63,6 +65,11 @@ void FilterValues(IntegerVariable var, Model* model,
|
||||
auto copy = it++;
|
||||
if (v < lb || v > ub || !ContainsKey(encoding, IntegerValue(v))) {
|
||||
values->erase(copy);
|
||||
} else {
|
||||
const Literal literal = FindOrDie(encoding, IntegerValue(v));
|
||||
if (assignment.LiteralIsFalse(literal)) {
|
||||
values->erase(copy);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -195,6 +202,28 @@ std::function<void(Model*)> TransitionConstraint(
|
||||
}
|
||||
}
|
||||
|
||||
// Construct a table with the possible values of each vars.
|
||||
std::vector<hash_set<int64>> possible_values(n);
|
||||
const VariablesAssignment& assignment =
|
||||
model->GetOrCreate<Trail>()->Assignment();
|
||||
for (int time = 0; time < n; ++time) {
|
||||
if (encoder->VariableIsFullyEncoded(vars[time])) {
|
||||
for (const auto& entry : encoder->FullDomainEncoding(vars[time])) {
|
||||
if (!assignment.LiteralIsFalse(entry.literal)) {
|
||||
possible_values[time].insert(entry.value.value());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int64 lb = model->Get(LowerBound(vars[time]));
|
||||
const int64 ub = model->Get(UpperBound(vars[time]));
|
||||
for (const std::vector<int64>& transition : automata) {
|
||||
if (lb <= transition[1] && transition[1] <= ub) {
|
||||
possible_values[time].insert(transition[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the set of reachable state at each time point.
|
||||
std::vector<std::set<int64>> reachable_states(n + 1);
|
||||
reachable_states[0].insert(initial_state);
|
||||
@@ -207,6 +236,7 @@ std::function<void(Model*)> TransitionConstraint(
|
||||
for (int time = 0; time + 1 < n; ++time) {
|
||||
for (const std::vector<int64>& transition : automata) {
|
||||
if (!ContainsKey(reachable_states[time], transition[0])) continue;
|
||||
if (!ContainsKey(possible_values[time], transition[1])) continue;
|
||||
reachable_states[time + 1].insert(transition[2]);
|
||||
}
|
||||
}
|
||||
@@ -216,6 +246,7 @@ std::function<void(Model*)> TransitionConstraint(
|
||||
std::set<int64> new_set;
|
||||
for (const std::vector<int64>& transition : automata) {
|
||||
if (!ContainsKey(reachable_states[time], transition[0])) continue;
|
||||
if (!ContainsKey(possible_values[time], transition[1])) continue;
|
||||
if (!ContainsKey(reachable_states[time + 1], transition[2])) continue;
|
||||
new_set.insert(transition[0]);
|
||||
}
|
||||
@@ -240,6 +271,7 @@ std::function<void(Model*)> TransitionConstraint(
|
||||
std::vector<IntegerValue> out_states;
|
||||
for (const std::vector<int64>& transition : automata) {
|
||||
if (!ContainsKey(reachable_states[time], transition[0])) continue;
|
||||
if (!ContainsKey(possible_values[time], transition[1])) continue;
|
||||
if (!ContainsKey(reachable_states[time + 1], transition[2])) continue;
|
||||
|
||||
// TODO(user): if this transition correspond to just one in-state or
|
||||
|
||||
Reference in New Issue
Block a user