diff --git a/src/sat/integer.cc b/src/sat/integer.cc index e04249b7c4..fe69eb1456 100644 --- a/src/sat/integer.cc +++ b/src/sat/integer.cc @@ -47,7 +47,8 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var, } } if (num_fixed > 0) { - LOG(WARNING) << "Domain intersection removed " << num_fixed << " values."; + LOG(WARNING) << "Domain intersection removed " << num_fixed << " values " + << "(out of " << encoding.size() << ")."; } return; } @@ -135,13 +136,21 @@ void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) { auto after_it = map_ref.lower_bound(i_lit.bound); if (after_it != map_ref.end()) { // Literal(after) => literal - sat_solver_->AddBinaryClauseDuringSearch(after_it->second.Negated(), - literal); + if (sat_solver_->CurrentDecisionLevel() == 0) { + sat_solver_->AddBinaryClause(after_it->second.Negated(), literal); + } else { + sat_solver_->AddBinaryClauseDuringSearch(after_it->second.Negated(), + literal); + } } if (after_it != map_ref.begin()) { // literal => Literal(before) - sat_solver_->AddBinaryClauseDuringSearch(literal.Negated(), - (--after_it)->second); + if (sat_solver_->CurrentDecisionLevel() == 0) { + sat_solver_->AddBinaryClause(literal.Negated(), (--after_it)->second); + } else { + sat_solver_->AddBinaryClauseDuringSearch(literal.Negated(), + (--after_it)->second); + } } // Add the new entry. diff --git a/src/sat/integer.h b/src/sat/integer.h index e234697cc2..9b730c68a9 100644 --- a/src/sat/integer.h +++ b/src/sat/integer.h @@ -862,10 +862,13 @@ inline std::function ReifiedInInterval(IntegerVariable v, IntegerEncoder* encoder = model->GetOrCreate(); const auto lb_lit = IntegerLiteral::GreaterOrEqual(v, IntegerValue(lb)); const auto ub_lit = IntegerLiteral::LowerOrEqual(v, IntegerValue(ub)); - if (lb < model->Get(LowerBound(v))) { - CHECK_LT(ub, model->Get(UpperBound(v))) << "Should be presolved."; - model->Add(Equality(ub_lit, in_interval)); - } else if (ub > model->Get(UpperBound(v))) { + if (lb <= model->Get(LowerBound(v))) { + if (ub >= model->Get(UpperBound(v))) { + model->GetOrCreate()->AddUnitClause(in_interval); + } else { + model->Add(Equality(ub_lit, in_interval)); + } + } else if (ub >= model->Get(UpperBound(v))) { model->Add(Equality(lb_lit, in_interval)); } else { const Literal is_ge_lb = encoder->GetOrCreateAssociatedLiteral(lb_lit); diff --git a/src/sat/table.cc b/src/sat/table.cc index 643f522501..fc2f2f933d 100644 --- a/src/sat/table.cc +++ b/src/sat/table.cc @@ -56,6 +56,8 @@ void FilterValues(IntegerVariable var, Model* model, const int64 ub = model->Get(UpperBound(var)); IntegerEncoder* encoder = model->GetOrCreate(); + const VariablesAssignment& assignment = + model->GetOrCreate()->Assignment(); if (encoder->VariableIsFullyEncoded(var)) { const auto encoding = GetEncoding(var, model); for (auto it = values->begin(); it != values->end();) { @@ -63,6 +65,11 @@ void FilterValues(IntegerVariable var, Model* model, auto copy = it++; if (v < lb || v > ub || !ContainsKey(encoding, IntegerValue(v))) { values->erase(copy); + } else { + const Literal literal = FindOrDie(encoding, IntegerValue(v)); + if (assignment.LiteralIsFalse(literal)) { + values->erase(copy); + } } } } else { @@ -195,6 +202,28 @@ std::function TransitionConstraint( } } + // Construct a table with the possible values of each vars. + std::vector> possible_values(n); + const VariablesAssignment& assignment = + model->GetOrCreate()->Assignment(); + for (int time = 0; time < n; ++time) { + if (encoder->VariableIsFullyEncoded(vars[time])) { + for (const auto& entry : encoder->FullDomainEncoding(vars[time])) { + if (!assignment.LiteralIsFalse(entry.literal)) { + possible_values[time].insert(entry.value.value()); + } + } + } else { + const int64 lb = model->Get(LowerBound(vars[time])); + const int64 ub = model->Get(UpperBound(vars[time])); + for (const std::vector& transition : automata) { + if (lb <= transition[1] && transition[1] <= ub) { + possible_values[time].insert(transition[1]); + } + } + } + } + // Compute the set of reachable state at each time point. std::vector> reachable_states(n + 1); reachable_states[0].insert(initial_state); @@ -207,6 +236,7 @@ std::function TransitionConstraint( for (int time = 0; time + 1 < n; ++time) { for (const std::vector& transition : automata) { if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!ContainsKey(possible_values[time], transition[1])) continue; reachable_states[time + 1].insert(transition[2]); } } @@ -216,6 +246,7 @@ std::function TransitionConstraint( std::set new_set; for (const std::vector& transition : automata) { if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!ContainsKey(possible_values[time], transition[1])) continue; if (!ContainsKey(reachable_states[time + 1], transition[2])) continue; new_set.insert(transition[0]); } @@ -240,6 +271,7 @@ std::function TransitionConstraint( std::vector out_states; for (const std::vector& transition : automata) { if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!ContainsKey(possible_values[time], transition[1])) continue; if (!ContainsKey(reachable_states[time + 1], transition[2])) continue; // TODO(user): if this transition correspond to just one in-state or