[CP-SAT] Enable linear expressions in AddMinEquality/AddAbsEquality/AddMaxEquality, and/or add AddLinMinEquality/AddLinMaxEquality; detect more encoding and remove variables only used in encodings

This commit is contained in:
Laurent Perron
2021-10-05 17:20:38 +02:00
parent 7fc810118e
commit ef59381229
11 changed files with 433 additions and 374 deletions

View File

@@ -24,6 +24,7 @@ import com.google.ortools.sat.DecisionStrategyProto;
import com.google.ortools.sat.ElementConstraintProto;
import com.google.ortools.sat.IntegerArgumentProto;
import com.google.ortools.sat.InverseConstraintProto;
import com.google.ortools.sat.LinearArgumentProto;
import com.google.ortools.sat.LinearConstraintProto;
import com.google.ortools.sat.LinearExpressionProto;
import com.google.ortools.sat.NoOverlap2DConstraintProto;
@@ -645,23 +646,45 @@ public final class CpModel {
}
/** Adds {@code target == Min(vars)}. */
public Constraint addMinEquality(IntVar target, IntVar[] vars) {
public Constraint addMinEquality(LinearExpr target, IntVar[] vars) {
Constraint ct = new Constraint(modelBuilder);
IntegerArgumentProto.Builder intMin =
ct.getBuilder().getIntMinBuilder().setTarget(target.getIndex());
LinearArgumentProto.Builder linMax = ct.getBuilder().getLinMaxBuilder();
linMax.setTarget(getLinearExpressionProtoBuilderFromLinearExpr(target, /*negate=*/true));
for (IntVar var : vars) {
intMin.addVars(var.getIndex());
linMax.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(var, /*negate=*/true));
}
return ct;
}
/** Adds {@code target == Min(exprs)}. */
public Constraint addLinMinEquality(LinearExpr target, LinearExpr[] exprs) {
Constraint ct = new Constraint(modelBuilder);
LinearArgumentProto.Builder linMax = ct.getBuilder().getLinMaxBuilder();
linMax.setTarget(getLinearExpressionProtoBuilderFromLinearExpr(target, /*negate=*/true));
for (LinearExpr expr : exprs) {
linMax.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(expr, /*negate=*/true));
}
return ct;
}
/** Adds {@code target == Max(vars)}. */
public Constraint addMaxEquality(IntVar target, IntVar[] vars) {
public Constraint addMaxEquality(LinearExpr target, IntVar[] vars) {
Constraint ct = new Constraint(modelBuilder);
IntegerArgumentProto.Builder intMax =
ct.getBuilder().getIntMaxBuilder().setTarget(target.getIndex());
LinearArgumentProto.Builder linMax = ct.getBuilder().getLinMaxBuilder();
linMax.setTarget(getLinearExpressionProtoBuilderFromLinearExpr(target, /*negate=*/false));
for (IntVar var : vars) {
intMax.addVars(var.getIndex());
linMax.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(var, /*negate=*/false));
}
return ct;
}
/** Adds {@code target == Max(exprs)}. */
public Constraint addLinMaxEquality(LinearExpr target, LinearExpr[] exprs) {
Constraint ct = new Constraint(modelBuilder);
LinearArgumentProto.Builder linMax = ct.getBuilder().getLinMaxBuilder();
linMax.setTarget(getLinearExpressionProtoBuilderFromLinearExpr(target, /*negate=*/false));
for (LinearExpr expr : exprs) {
linMax.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(expr, /*negate=*/false));
}
return ct;
}
@@ -677,14 +700,13 @@ public final class CpModel {
return ct;
}
/** Adds {@code target == Abs(var)}. */
public Constraint addAbsEquality(IntVar target, IntVar var) {
/** Adds {@code target == Abs(expr)}. */
public Constraint addAbsEquality(LinearExpr target, LinearExpr expr) {
Constraint ct = new Constraint(modelBuilder);
ct.getBuilder()
.getIntMaxBuilder()
.setTarget(target.getIndex())
.addVars(var.getIndex())
.addVars(-var.getIndex() - 1);
LinearArgumentProto.Builder linMax = ct.getBuilder().getLinMaxBuilder();
linMax.setTarget(getLinearExpressionProtoBuilderFromLinearExpr(target, /*negate=*/false));
linMax.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(expr, /*negate=*/false));
linMax.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(expr, /*negate=*/true));
return ct;
}
@@ -740,9 +762,10 @@ public final class CpModel {
public IntervalVar newIntervalVar(
LinearExpr start, LinearExpr size, LinearExpr end, String name) {
addEquality(new Sum(start, size), end);
return new IntervalVar(modelBuilder, getLinearExpressionProtoBuilderFromLinearExpr(start),
getLinearExpressionProtoBuilderFromLinearExpr(size),
getLinearExpressionProtoBuilderFromLinearExpr(end), name);
return new IntervalVar(modelBuilder,
getLinearExpressionProtoBuilderFromLinearExpr(start, /*negate=*/false),
getLinearExpressionProtoBuilderFromLinearExpr(size, /*negate=*/false),
getLinearExpressionProtoBuilderFromLinearExpr(end, /*negate=*/false), name);
}
/**
@@ -757,9 +780,11 @@ public final class CpModel {
* @return An IntervalVar object
*/
public IntervalVar newFixedSizeIntervalVar(LinearExpr start, long size, String name) {
return new IntervalVar(modelBuilder, getLinearExpressionProtoBuilderFromLinearExpr(start),
return new IntervalVar(modelBuilder,
getLinearExpressionProtoBuilderFromLinearExpr(start, /*negate=*/false),
getLinearExpressionProtoBuilderFromLong(size),
getLinearExpressionProtoBuilderFromLinearExpr(new Sum(start, size)), name);
getLinearExpressionProtoBuilderFromLinearExpr(new Sum(start, size), /*negate=*/false),
name);
}
/** Creates a fixed interval from its start and its size. */
@@ -790,9 +815,11 @@ public final class CpModel {
public IntervalVar newOptionalIntervalVar(
LinearExpr start, LinearExpr size, LinearExpr end, Literal isPresent, String name) {
addEquality(new Sum(start, size), end).onlyEnforceIf(isPresent);
return new IntervalVar(modelBuilder, getLinearExpressionProtoBuilderFromLinearExpr(start),
getLinearExpressionProtoBuilderFromLinearExpr(size),
getLinearExpressionProtoBuilderFromLinearExpr(end), isPresent.getIndex(), name);
return new IntervalVar(modelBuilder,
getLinearExpressionProtoBuilderFromLinearExpr(start, /*negate=*/false),
getLinearExpressionProtoBuilderFromLinearExpr(size, /*negate=*/false),
getLinearExpressionProtoBuilderFromLinearExpr(end, /*negate=*/false), isPresent.getIndex(),
name);
}
/**
@@ -810,10 +837,11 @@ public final class CpModel {
*/
public IntervalVar newOptionalFixedSizeIntervalVar(
LinearExpr start, long size, Literal isPresent, String name) {
return new IntervalVar(modelBuilder, getLinearExpressionProtoBuilderFromLinearExpr(start),
return new IntervalVar(modelBuilder,
getLinearExpressionProtoBuilderFromLinearExpr(start, /*negate=*/false),
getLinearExpressionProtoBuilderFromLong(size),
getLinearExpressionProtoBuilderFromLinearExpr(new Sum(start, size)), isPresent.getIndex(),
name);
getLinearExpressionProtoBuilderFromLinearExpr(new Sum(start, size), /*negate=*/false),
isPresent.getIndex(), name);
}
/** Creates an optional fixed interval from start and size, and an isPresent literal. */
@@ -1063,14 +1091,16 @@ public final class CpModel {
return constVar.getIndex();
}
LinearExpressionProto.Builder getLinearExpressionProtoBuilderFromLinearExpr(LinearExpr expr) {
LinearExpressionProto.Builder getLinearExpressionProtoBuilderFromLinearExpr(
LinearExpr expr, boolean negate) {
LinearExpressionProto.Builder builder = LinearExpressionProto.newBuilder();
final int numVariables = expr.numElements();
final long mult = negate ? -1 : 1;
for (int i = 0; i < numVariables; ++i) {
builder.addVars(expr.getVariable(i).getIndex());
builder.addCoeffs(expr.getCoefficient(i));
builder.addCoeffs(expr.getCoefficient(i) * mult);
}
builder.setOffset(expr.getOffset());
builder.setOffset(expr.getOffset() * mult);
return builder;
}

View File

@@ -353,6 +353,16 @@ void CumulativeConstraint::AddDemand(IntervalVar interval, IntVar demand) {
builder_->GetOrCreateIntegerIndex(demand.index_));
}
void CumulativeConstraint::AddDemandWithEnergy(IntervalVar interval,
IntVar demand,
const LinearExpr& energy) {
proto_->mutable_cumulative()->add_intervals(interval.index_);
proto_->mutable_cumulative()->add_demands(
builder_->GetOrCreateIntegerIndex(demand.index_));
*proto_->mutable_cumulative()->add_energies() =
builder_->LinearExprToProto(energy);
}
IntervalVar::IntervalVar() : cp_model_(nullptr), index_() {}
IntervalVar::IntervalVar(int index, CpModelProto* cp_model)
@@ -488,9 +498,9 @@ IntervalVar CpModelBuilder::NewOptionalIntervalVar(const LinearExpr& start,
ConstraintProto* const ct = cp_model_.add_constraints();
ct->add_enforcement_literal(presence.index_);
IntervalConstraintProto* const interval = ct->mutable_interval();
LinearExprToProto(start, interval->mutable_start_view());
LinearExprToProto(size, interval->mutable_size_view());
LinearExprToProto(end, interval->mutable_end_view());
*interval->mutable_start_view() = LinearExprToProto(start);
*interval->mutable_size_view() = LinearExprToProto(size);
*interval->mutable_end_view() = LinearExprToProto(end);
return IntervalVar(index, &cp_model_);
}
@@ -500,9 +510,9 @@ IntervalVar CpModelBuilder::NewOptionalFixedSizeIntervalVar(
ConstraintProto* const ct = cp_model_.add_constraints();
ct->add_enforcement_literal(presence.index_);
IntervalConstraintProto* const interval = ct->mutable_interval();
LinearExprToProto(start, interval->mutable_start_view());
*interval->mutable_start_view() = LinearExprToProto(start);
interval->mutable_size_view()->set_offset(size);
LinearExprToProto(start, interval->mutable_end_view());
*interval->mutable_end_view() = LinearExprToProto(start);
interval->mutable_end_view()->set_offset(interval->end_view().offset() +
size);
return IntervalVar(index, &cp_model_);
@@ -722,25 +732,18 @@ AutomatonConstraint CpModelBuilder::AddAutomaton(
return AutomatonConstraint(proto);
}
Constraint CpModelBuilder::AddMinEquality(IntVar target,
absl::Span<const IntVar> vars) {
ConstraintProto* const proto = cp_model_.add_constraints();
proto->mutable_int_min()->set_target(GetOrCreateIntegerIndex(target.index_));
for (const IntVar& var : vars) {
proto->mutable_int_min()->add_vars(GetOrCreateIntegerIndex(var.index_));
}
return Constraint(proto);
}
void CpModelBuilder::LinearExprToProto(const LinearExpr& expr,
LinearExpressionProto* expr_proto) {
LinearExpressionProto CpModelBuilder::LinearExprToProto(const LinearExpr& expr,
bool negate) {
LinearExpressionProto expr_proto;
const int64_t mult = negate ? -1 : 1;
for (const IntVar var : expr.variables()) {
expr_proto->add_vars(GetOrCreateIntegerIndex(var.index_));
expr_proto.add_vars(GetOrCreateIntegerIndex(var.index_));
}
for (const int64_t coeff : expr.coefficients()) {
expr_proto->add_coeffs(coeff);
expr_proto.add_coeffs(coeff * mult);
}
expr_proto->set_offset(expr.constant());
expr_proto.set_offset(expr.constant() * mult);
return expr_proto;
}
LinearExpr CpModelBuilder::LinearExprFromProto(
@@ -753,36 +756,48 @@ LinearExpr CpModelBuilder::LinearExprFromProto(
return result;
}
Constraint CpModelBuilder::AddLinMinEquality(
const LinearExpr& target, absl::Span<const LinearExpr> exprs) {
ConstraintProto* const proto = cp_model_.add_constraints();
LinearExprToProto(target, proto->mutable_lin_min()->mutable_target());
for (const LinearExpr& expr : exprs) {
LinearExpressionProto* expr_proto = proto->mutable_lin_min()->add_exprs();
LinearExprToProto(expr, expr_proto);
Constraint CpModelBuilder::AddMinEquality(const LinearExpr& target,
absl::Span<const IntVar> vars) {
ConstraintProto* ct = cp_model_.add_constraints();
*ct->mutable_lin_max()->mutable_target() =
LinearExprToProto(target, /*negate=*/true);
for (const IntVar& var : vars) {
*ct->mutable_lin_max()->add_exprs() =
LinearExprToProto(var, /*negate=*/true);
}
return Constraint(proto);
return Constraint(ct);
}
Constraint CpModelBuilder::AddMaxEquality(IntVar target,
absl::Span<const IntVar> vars) {
ConstraintProto* const proto = cp_model_.add_constraints();
proto->mutable_int_max()->set_target(GetOrCreateIntegerIndex(target.index_));
for (const IntVar& var : vars) {
proto->mutable_int_max()->add_vars(GetOrCreateIntegerIndex(var.index_));
Constraint CpModelBuilder::AddLinMinEquality(
const LinearExpr& target, absl::Span<const LinearExpr> exprs) {
ConstraintProto* ct = cp_model_.add_constraints();
*ct->mutable_lin_max()->mutable_target() =
LinearExprToProto(target, /*negate=*/true);
for (const LinearExpr& expr : exprs) {
*ct->mutable_lin_max()->add_exprs() =
LinearExprToProto(expr, /*negate=*/true);
}
return Constraint(proto);
return Constraint(ct);
}
Constraint CpModelBuilder::AddMaxEquality(const LinearExpr& target,
absl::Span<const IntVar> vars) {
ConstraintProto* ct = cp_model_.add_constraints();
*ct->mutable_lin_max()->mutable_target() = LinearExprToProto(target);
for (const IntVar& var : vars) {
*ct->mutable_lin_max()->add_exprs() = LinearExprToProto(var);
}
return Constraint(ct);
}
Constraint CpModelBuilder::AddLinMaxEquality(
const LinearExpr& target, absl::Span<const LinearExpr> exprs) {
ConstraintProto* const proto = cp_model_.add_constraints();
LinearExprToProto(target, proto->mutable_lin_max()->mutable_target());
ConstraintProto* ct = cp_model_.add_constraints();
*ct->mutable_lin_max()->mutable_target() = LinearExprToProto(target);
for (const LinearExpr& expr : exprs) {
LinearExpressionProto* expr_proto = proto->mutable_lin_max()->add_exprs();
LinearExprToProto(expr, expr_proto);
*ct->mutable_lin_max()->add_exprs() = LinearExprToProto(expr);
}
return Constraint(proto);
return Constraint(ct);
}
Constraint CpModelBuilder::AddDivisionEquality(IntVar target, IntVar numerator,
@@ -795,12 +810,13 @@ Constraint CpModelBuilder::AddDivisionEquality(IntVar target, IntVar numerator,
return Constraint(proto);
}
Constraint CpModelBuilder::AddAbsEquality(IntVar target, IntVar var) {
Constraint CpModelBuilder::AddAbsEquality(const LinearExpr& target,
const LinearExpr& expr) {
ConstraintProto* const proto = cp_model_.add_constraints();
proto->mutable_int_max()->set_target(GetOrCreateIntegerIndex(target.index_));
proto->mutable_int_max()->add_vars(GetOrCreateIntegerIndex(var.index_));
proto->mutable_int_max()->add_vars(
NegatedRef(GetOrCreateIntegerIndex(var.index_)));
*proto->mutable_lin_max()->mutable_target() = LinearExprToProto(target);
*proto->mutable_lin_max()->add_exprs() = LinearExprToProto(expr);
*proto->mutable_lin_max()->add_exprs() =
LinearExprToProto(expr, /*negate=*/true);
return Constraint(proto);
}
@@ -977,14 +993,6 @@ int64_t SolutionIntegerValue(const CpSolverResponse& r,
return result;
}
int64_t SolutionIntegerMin(const CpSolverResponse& r, IntVar x) {
return r.solution(x.index_);
}
int64_t SolutionIntegerMax(const CpSolverResponse& r, IntVar x) {
return r.solution(x.index_);
}
bool SolutionBooleanValue(const CpSolverResponse& r, BoolVar x) {
const int ref = x.index_;
if (RefIsPositive(ref)) {

View File

@@ -201,8 +201,6 @@ class IntVar {
friend class ReservoirConstraint;
friend int64_t SolutionIntegerValue(const CpSolverResponse& r,
const LinearExpr& expr);
friend int64_t SolutionIntegerMin(const CpSolverResponse& r, IntVar x);
friend int64_t SolutionIntegerMax(const CpSolverResponse& r, IntVar x);
IntVar(int index, CpModelProto* cp_model);
@@ -593,12 +591,23 @@ class NoOverlap2DConstraint : public Constraint {
*
* This constraint allows adding fixed or variables demands to the cumulative
* constraint incrementally.
*
* One cannot mix the AddDemand() and AddDemandWithEnergy() APIs in the same
* cumulative API. Either always supply energy info, or never.
*/
class CumulativeConstraint : public Constraint {
public:
/// Adds a pair (interval, demand) to the constraint.
void AddDemand(IntervalVar interval, IntVar demand);
/// Adds a demand with a linear expression representing the energy.
/// The constraint will check that:
/// energy == size(interval) * demand
/// This extra information is redundant, and helps the linear relaxation of
/// the problem.
void AddDemandWithEnergy(IntervalVar interval, IntVar demand,
const LinearExpr& energy);
private:
friend class CpModelBuilder;
@@ -818,14 +827,16 @@ class CpModelBuilder {
absl::Span<const int> final_states);
/// Adds target == min(vars).
Constraint AddMinEquality(IntVar target, absl::Span<const IntVar> vars);
Constraint AddMinEquality(const LinearExpr& target,
absl::Span<const IntVar> vars);
/// Adds target == min(exprs).
Constraint AddLinMinEquality(const LinearExpr& target,
absl::Span<const LinearExpr> exprs);
/// Adds target == max(vars).
Constraint AddMaxEquality(IntVar target, absl::Span<const IntVar> vars);
Constraint AddMaxEquality(const LinearExpr& target,
absl::Span<const IntVar> vars);
/// Adds target == max(exprs).
Constraint AddLinMaxEquality(const LinearExpr& target,
@@ -835,8 +846,8 @@ class CpModelBuilder {
Constraint AddDivisionEquality(IntVar target, IntVar numerator,
IntVar denominator);
/// Adds target == abs(var).
Constraint AddAbsEquality(IntVar target, IntVar var);
/// Adds target == abs(expr).
Constraint AddAbsEquality(const LinearExpr& target, const LinearExpr& expr);
/// Adds target = var % mod.
Constraint AddModuloEquality(IntVar target, IntVar var, IntVar mod);
@@ -929,8 +940,8 @@ class CpModelBuilder {
friend class IntervalVar;
// Fills the 'expr_proto' with the linear expression represented by 'expr'.
void LinearExprToProto(const LinearExpr& expr,
LinearExpressionProto* expr_proto);
LinearExpressionProto LinearExprToProto(const LinearExpr& expr,
bool negate = false);
// Rebuilds a LinearExpr from a LinearExpressionProto.
// This method is a member of CpModelBuilder because it needs to be friend
@@ -959,12 +970,6 @@ class CpModelBuilder {
/// Evaluates the value of an linear expression in a solver response.
int64_t SolutionIntegerValue(const CpSolverResponse& r, const LinearExpr& expr);
/// Returns the min of an integer variable in a solution.
int64_t SolutionIntegerMin(const CpSolverResponse& r, IntVar x);
/// Returns the max of an integer variable in a solution.
int64_t SolutionIntegerMax(const CpSolverResponse& r, IntVar x);
/// Evaluates the value of a Boolean literal in a solver response.
bool SolutionBooleanValue(const CpSolverResponse& r, BoolVar x);

View File

@@ -377,25 +377,21 @@ message ConstraintProto {
// is the same as the sign of vars[0].
IntegerArgumentProto int_mod = 8;
// The int_max constraint forces the target to equal the maximum of all
// variables.
// The deprecated int_max constraint forces the target to equal the maximum
// of all variables.
//
// The lin_max constraint forces the target to equal the maximum of all
// linear expressions.
//
// TODO(user): Remove int_max in favor of lin_max.
IntegerArgumentProto int_max = 9;
IntegerArgumentProto int_max = 9 [deprecated = true];
LinearArgumentProto lin_max = 27;
// The int_min constraint forces the target to equal the minimum of all
// variables.
// The deprecated int_min constraint forces the target to equal the minimum
// of all variables.
//
// The lin_min constraint forces the target to equal the minimum of all
// linear expressions.
//
// TODO(user): Remove int_min in favor of lin_min.
IntegerArgumentProto int_min = 10;
LinearArgumentProto lin_min = 28;
// The deprecated lin_min constraint forces the target to equal the minimum
// of all linear expressions.
IntegerArgumentProto int_min = 10 [deprecated = true];
LinearArgumentProto lin_min = 28 [deprecated = true];
// The int_prod constraint forces the target to equal the product of all
// variables. By convention, because we can just remove term equal to one,

View File

@@ -124,7 +124,7 @@ bool CpModelPresolver::PresolveEnforcementLiteral(ConstraintProto* ct) {
// same polarity.
if (context_->VariableWithCostIsUniqueAndRemovable(literal)) {
const int64_t obj_coeff =
gtl::FindOrDie(context_->ObjectiveMap(), PositiveRef(literal));
context_->ObjectiveMap().at(PositiveRef(literal));
if (RefIsPositive(literal) == (obj_coeff > 0)) {
// It is just more advantageous to set it to false!
context_->UpdateRuleStats("enforcement literal with unique direction");
@@ -364,7 +364,7 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) {
const int enforcement = ct->enforcement_literal(0);
if (context_->VariableWithCostIsUniqueAndRemovable(enforcement)) {
int var = PositiveRef(enforcement);
int64_t obj_coeff = gtl::FindOrDie(context_->ObjectiveMap(), var);
int64_t obj_coeff = context_->ObjectiveMap().at(var);
if (!RefIsPositive(enforcement)) obj_coeff = -obj_coeff;
// The other case where the constraint is redundant is treated elsewhere.
@@ -390,7 +390,7 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) {
context_->tmp_literal_set.clear();
for (const int literal : *literals) {
if (context_->tmp_literal_set.contains(literal)) {
if (!context_->SetLiteralToFalse(literal)) return true;
if (!context_->SetLiteralToFalse(literal)) return false;
context_->UpdateRuleStats(absl::StrCat(name, "duplicate literals"));
}
if (context_->tmp_literal_set.contains(NegatedRef(literal))) {
@@ -398,7 +398,7 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) {
int num_negative = 0;
for (const int other : *literals) {
if (PositiveRef(other) != PositiveRef(literal)) {
if (!context_->SetLiteralToFalse(other)) return true;
if (!context_->SetLiteralToFalse(other)) return false;
context_->UpdateRuleStats(absl::StrCat(name, "x and not(x)"));
} else {
if (other == literal) {
@@ -412,10 +412,10 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) {
// This is tricky for the case where the at most one reduce to (lit,
// not(lit), not(lit)) for instance.
if (num_positive > 1 && !context_->SetLiteralToFalse(literal)) {
return true;
return false;
}
if (num_negative > 1 && !context_->SetLiteralToTrue(literal)) {
return true;
return false;
}
return RemoveConstraint(ct);
}
@@ -431,7 +431,7 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) {
context_->UpdateRuleStats(absl::StrCat(name, "satisfied"));
for (const int other : *literals) {
if (other != literal) {
if (!context_->SetLiteralToFalse(other)) return true;
if (!context_->SetLiteralToFalse(other)) return false;
}
}
return RemoveConstraint(ct);
@@ -499,6 +499,7 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) {
bool CpModelPresolver::PresolveAtMostOne(ConstraintProto* ct) {
if (context_->ModelIsUnsat()) return false;
CHECK(!HasEnforcementLiteral(*ct));
const bool changed = PresolveAtMostOrExactlyOne(ct);
if (ct->constraint_case() != ConstraintProto::kAtMostOne) return changed;
@@ -1456,8 +1457,8 @@ bool CpModelPresolver::ExploitEquivalenceRelations(int c, ConstraintProto* ct) {
return changed;
}
void CpModelPresolver::DivideLinearByGcd(ConstraintProto* ct) {
if (context_->ModelIsUnsat()) return;
bool CpModelPresolver::DivideLinearByGcd(ConstraintProto* ct) {
if (context_->ModelIsUnsat()) return false;
// Compute the GCD of all coefficients.
int64_t gcd = 0;
@@ -1475,9 +1476,10 @@ void CpModelPresolver::DivideLinearByGcd(ConstraintProto* ct) {
const Domain rhs = ReadDomainFromProto(ct->linear());
FillDomainInProto(rhs.InverseMultiplicationBy(gcd), ct->mutable_linear());
if (ct->linear().domain_size() == 0) {
return (void)MarkConstraintAsFalse(ct);
return MarkConstraintAsFalse(ct);
}
}
return false;
}
template <typename ProtoWithVarsAndCoeffs>
@@ -1583,15 +1585,15 @@ bool CpModelPresolver::CanonicalizeLinear(ConstraintProto* ct) {
}
int64_t offset = 0;
const bool result =
bool changed =
CanonicalizeLinearExpressionInternal(*ct, ct->mutable_linear(), &offset);
if (offset != 0) {
FillDomainInProto(
ReadDomainFromProto(ct->linear()).AdditionWith(Domain(-offset)),
ct->mutable_linear());
}
DivideLinearByGcd(ct);
return result;
changed |= DivideLinearByGcd(ct);
return changed;
}
bool CpModelPresolver::RemoveSingletonInLinear(ConstraintProto* ct) {
@@ -1662,8 +1664,7 @@ bool CpModelPresolver::RemoveSingletonInLinear(ConstraintProto* ct) {
//
// TODO(user): If the objective is a single variable, we can actually
// "absorb" any factor into the objective scaling.
const int64_t objective_coeff =
gtl::FindOrDie(context_->ObjectiveMap(), var);
const int64_t objective_coeff = context_->ObjectiveMap().at(var);
CHECK_NE(coeff, 0);
if (objective_coeff % coeff != 0) continue;
@@ -1933,6 +1934,31 @@ bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) {
}
}
// This is just an implication, lets convert it right away.
if (ct->linear().vars_size() == 1 && ct->enforcement_literal_size() > 0 &&
context_->CanBeUsedAsLiteral(ct->linear().vars(0))) {
const Domain rhs = ReadDomainFromProto(ct->linear());
const bool zero_ok = rhs.Contains(0);
const bool one_ok = rhs.Contains(ct->linear().coeffs(0));
context_->UpdateRuleStats("linear: is boolean implication");
if (!zero_ok && !one_ok) {
return MarkConstraintAsFalse(ct);
}
if (zero_ok && one_ok) {
return RemoveConstraint(ct);
}
const int ref = ct->linear().vars(0);
if (zero_ok) {
ct->mutable_bool_and()->add_literals(NegatedRef(ref));
} else {
ct->mutable_bool_and()->add_literals(ref);
}
// No var <-> constraint graph changes.
// But this is no longer a linear1.
return true;
}
// If the constraint is literal => x in domain and x = abs(abs_arg), we can
// replace x by abs_arg and hopefully remove the variable x later.
int abs_arg;
@@ -2206,7 +2232,7 @@ bool CpModelPresolver::DetectAndProcessOneSidedLinearConstraint(
const int size =
context_->VarToConstraints(var).size() - (is_in_objective ? 1 : 0);
const int64_t obj_coeff =
is_in_objective ? gtl::FindOrDie(context_->ObjectiveMap(), var) : 0;
is_in_objective ? context_->ObjectiveMap().at(var) : 0;
// We cannot fix anything if the domain of the objective is excluding
// some objective values.
@@ -2343,7 +2369,7 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index,
// variable altogether.
if (rhs.Min() != rhs.Max() &&
context_->VariableWithCostIsUniqueAndRemovable(var)) {
const int64_t obj_coeff = gtl::FindOrDie(context_->ObjectiveMap(), var);
const int64_t obj_coeff = context_->ObjectiveMap().at(var);
const bool same_sign = (var_coeff > 0) == (obj_coeff > 0);
bool fixed = false;
if (same_sign && RhsCanBeFixedToMin(var_coeff, context_->DomainOf(var),
@@ -6518,16 +6544,18 @@ void CpModelPresolver::TryToSimplifyDomain(int var) {
// TODO(user): The hint might get lost if the encoding was created during
// the presolve.
if (context_->VariableIsRemovable(var) &&
!context_->CanBeUsedAsLiteral(var) &&
context_->VariableIsOnlyUsedInEncodingAndMaybeInObjective(var) &&
context_->params().search_branching() != SatParameters::FIXED_SEARCH) {
// Detect the full encoding case without extra constraint.
// This is the simplest to deal with as we can just add an exactly one
// constraint and remove all the linear1.
std::vector<int> literals;
std::vector<int> equality_constraints;
std::vector<int> other_constraints;
absl::flat_hash_map<int64_t, int> value_to_equal_literal;
absl::flat_hash_map<int64_t, int> value_to_not_equal_literal;
// We can currently only deal with the case where all encoding constraint
// are of the form literal => var ==/!= value.
// If they are more complex linear1 involved, we just abort.
//
// TODO(user): Also deal with the case all >= or <= where we can add a
// serie of implication between all involved literals.
absl::flat_hash_set<int64_t> values_set;
absl::flat_hash_map<int64_t, std::vector<int>> value_to_equal_literals;
absl::flat_hash_map<int64_t, std::vector<int>> value_to_not_equal_literals;
bool abort = false;
for (const int c : context_->VarToConstraints(var)) {
if (c < 0) continue;
@@ -6544,14 +6572,15 @@ void CpModelPresolver::TryToSimplifyDomain(int var) {
ReadDomainFromProto(ct.linear()).InverseMultiplicationBy(coeff);
if (rhs.IsFixed()) {
const int64_t value = rhs.FixedValue();
if (value_to_equal_literal.contains(value)) {
abort = true;
break;
if (!context_->DomainOf(var).Contains(rhs.FixedValue())) {
if (!context_->SetLiteralToFalse(ct.enforcement_literal(0))) {
return;
}
} else {
values_set.insert(rhs.FixedValue());
value_to_equal_literals[rhs.FixedValue()].push_back(
ct.enforcement_literal(0));
}
equality_constraints.push_back(c);
literals.push_back(ct.enforcement_literal(0));
value_to_equal_literal[value] = ct.enforcement_literal(0);
} else {
const Domain complement =
context_->DomainOf(var).IntersectionWith(rhs.Complement());
@@ -6561,41 +6590,58 @@ void CpModelPresolver::TryToSimplifyDomain(int var) {
break;
}
if (complement.IsFixed()) {
const int64_t value = complement.FixedValue();
if (value_to_not_equal_literal.contains(value)) {
abort = true;
break;
if (context_->DomainOf(var).Contains(complement.FixedValue())) {
values_set.insert(complement.FixedValue());
value_to_not_equal_literals[complement.FixedValue()].push_back(
ct.enforcement_literal(0));
}
other_constraints.push_back(c);
value_to_not_equal_literal[value] = ct.enforcement_literal(0);
} else {
abort = true;
}
}
}
// For a full encoding, we don't need all the not equal constraint to be
// present.
if (value_to_equal_literal.size() != context_->DomainOf(var).Size()) {
abort = true;
} else {
for (const int64_t value : context_->DomainOf(var).Values()) {
if (!value_to_equal_literal.contains(value)) {
abort = true;
break;
}
if (value_to_not_equal_literal.contains(value) &&
value_to_equal_literal[value] !=
NegatedRef(value_to_not_equal_literal[value])) {
abort = true;
break;
}
if (abort) break;
}
}
if (abort) {
context_->UpdateRuleStats("TODO variables: only used in encoding.");
context_->UpdateRuleStats("TODO variables: only used in linear1.");
} else if (value_to_not_equal_literals.empty() &&
value_to_equal_literals.empty()) {
// This is just a variable not used anywhere, it should be removed by
// another part of the presolve.
} else {
// For determinism, sort all the encoded values first.
std::vector<int64_t> encoded_values(values_set.begin(), values_set.end());
std::sort(encoded_values.begin(), encoded_values.end());
CHECK(!encoded_values.empty());
const bool is_fully_encoded =
encoded_values.size() == context_->DomainOf(var).Size();
// Link all Boolean in out linear1 to the encoding literals. Note that we
// should hopefully already have detected such literal before and this
// should add trivial implications.
for (const int64_t v : encoded_values) {
const int encoding_lit = context_->GetOrCreateVarValueEncoding(var, v);
const auto eq_it = value_to_equal_literals.find(v);
if (eq_it != value_to_equal_literals.end()) {
for (const int lit : eq_it->second) {
context_->AddImplication(lit, encoding_lit);
}
}
const auto neq_it = value_to_not_equal_literals.find(v);
if (neq_it != value_to_not_equal_literals.end()) {
for (const int lit : neq_it->second) {
context_->AddImplication(lit, NegatedRef(encoding_lit));
}
}
}
context_->UpdateNewConstraintsVariableUsage();
// This is the set of other values.
Domain other_values;
if (!is_fully_encoded) {
other_values = context_->DomainOf(var).IntersectionWith(
Domain::FromValues(encoded_values).Complement());
}
// Update the objective if needed. Note that this operation can fail if
// the new expression result in potential overflow.
if (context_->VarToConstraints(var).contains(kObjectiveConstraint)) {
@@ -6605,24 +6651,32 @@ void CpModelPresolver::TryToSimplifyDomain(int var) {
linear->add_vars(var);
linear->add_coeffs(coeff_in_equality);
std::vector<int64_t> all_values;
for (const auto entry : value_to_equal_literal) {
all_values.push_back(entry.first);
int64_t min_value;
if (is_fully_encoded) {
// We substract the min_value from all coefficients.
// This should reduce the objective size and helps with the bounds.
//
// TODO(user): If the objective coefficient is negative, then we
// should rather substract the max?
min_value = encoded_values[0];
} else {
// Tricky: If the variable is not fully encoded, then when all partial
// encoding literal are false, it must take the "best" value in
// other_values. That depend on the sign of the objective coeff.
//
// We also restrict other value so that the postsolve code below will
// fix the variable to the correct value when this happen.
const int64_t obj_coeff = context_->ObjectiveMap().at(var);
other_values =
Domain(obj_coeff > 0 ? other_values.Min() : other_values.Max());
min_value = other_values.FixedValue();
}
std::sort(all_values.begin(), all_values.end());
// We substract the min_value from all coefficients.
// This should reduce the objective size and helps with the bounds.
//
// TODO(user): If the objective coefficient is negative, then we
// should rather substract the max.
CHECK(!all_values.empty());
const int64_t min_value = all_values[0];
linear->add_domain(-min_value);
linear->add_domain(-min_value);
for (const int64_t value : all_values) {
for (const int64_t value : encoded_values) {
if (value == min_value) continue;
const int enf = value_to_equal_literal.at(value);
const int enf = context_->GetOrCreateVarValueEncoding(var, value);
const int64_t coeff = value - min_value;
if (RefIsPositive(enf)) {
linear->add_vars(enf);
@@ -6638,39 +6692,59 @@ void CpModelPresolver::TryToSimplifyDomain(int var) {
if (!context_->SubstituteVariableInObjective(var, coeff_in_equality,
encoding_ct)) {
context_->UpdateRuleStats(
"TODO variables: only used in objective and in full encoding");
"TODO variables: only used in objective and in encoding");
return;
}
context_->UpdateRuleStats(
"variables: only used in objective and in full encoding");
"variables: only used in objective and in encoding");
} else {
context_->UpdateRuleStats("variables: only used in full encoding");
context_->UpdateRuleStats("variables: only used in encoding");
}
// Move the encoding constraints to the mapping model. Note that only the
// equality constraint are needed. In fact if we add the other ones, our
// current limited postsolve code will not work.
for (const int c : equality_constraints) {
*context_->mapping_model->add_constraints() =
context_->working_model->constraints(c);
// Clear all involved constraint.
auto copy = context_->VarToConstraints(var);
for (const int c : copy) {
if (c < 0) continue;
context_->working_model->mutable_constraints(c)->Clear();
context_->UpdateConstraintVariableUsage(c);
}
for (const int c : other_constraints) {
context_->working_model->mutable_constraints(c)->Clear();
context_->UpdateConstraintVariableUsage(c);
// Add enough constraints to the mapping model to recover a valid value
// for var when all the booleans are fixed.
for (const int64_t value : encoded_values) {
const int enf = context_->GetOrCreateVarValueEncoding(var, value);
ConstraintProto* ct = context_->mapping_model->add_constraints();
ct->add_enforcement_literal(enf);
ct->mutable_linear()->add_vars(var);
ct->mutable_linear()->add_coeffs(1);
ct->mutable_linear()->add_domain(value);
ct->mutable_linear()->add_domain(value);
}
// This must be done after we removed all the constraint containing var.
ConstraintProto* new_ct = context_->working_model->add_constraints();
std::sort(literals.begin(), literals.end()); // For determinism.
for (const int literal : literals) {
new_ct->mutable_exactly_one()->add_literals(literal);
}
if (is_fully_encoded) {
// The encoding is full: add an exactly one.
for (const int64_t value : encoded_values) {
new_ct->mutable_exactly_one()->add_literals(
context_->GetOrCreateVarValueEncoding(var, value));
}
PresolveExactlyOne(new_ct);
} else {
// If all literal are false, then var must take one of the other values.
ConstraintProto* mapping_ct =
context_->mapping_model->add_constraints();
mapping_ct->mutable_linear()->add_vars(var);
mapping_ct->mutable_linear()->add_coeffs(1);
FillDomainInProto(other_values, mapping_ct->mutable_linear());
// In some cases there is duplicate literal, and we want to make sure
// this is presolved.
PresolveExactlyOne(new_ct);
for (const int64_t value : encoded_values) {
const int literal = context_->GetOrCreateVarValueEncoding(var, value);
mapping_ct->add_enforcement_literal(NegatedRef(literal));
new_ct->mutable_at_most_one()->add_literals(literal);
}
PresolveAtMostOne(new_ct);
}
context_->UpdateNewConstraintsVariableUsage();
context_->MarkVariableAsRemoved(var);

View File

@@ -179,7 +179,8 @@ class CpModelPresolver {
// Extracts AtMostOne constraint from Linear constraint.
void ExtractAtMostOneFromLinear(ConstraintProto* ct);
void DivideLinearByGcd(ConstraintProto* ct);
// Returns true if the constraint changed.
bool DivideLinearByGcd(ConstraintProto* ct);
void ExtractEnforcementLiteralFromLinearConstraint(int ct_index,
ConstraintProto* ct);

View File

@@ -1644,12 +1644,15 @@ void MinimizeL1DistanceWithHint(const CpModelProto& model_proto, Model* model) {
std::max(std::abs(min_domain), std::abs(max_domain));
abs_var_proto->add_domain(abs_min_domain);
abs_var_proto->add_domain(abs_max_domain);
ConstraintProto* const abs_constraint_proto =
updated_model_proto.add_constraints();
abs_constraint_proto->mutable_int_max()->set_target(abs_var_index);
abs_constraint_proto->mutable_int_max()->add_vars(new_var_index);
abs_constraint_proto->mutable_int_max()->add_vars(
NegatedRef(new_var_index));
auto* abs_ct = updated_model_proto.add_constraints()->mutable_lin_max();
abs_ct->mutable_target()->add_vars(abs_var_index);
abs_ct->mutable_target()->add_coeffs(1);
LinearExpressionProto* left = abs_ct->add_exprs();
left->add_vars(new_var_index);
left->add_coeffs(1);
LinearExpressionProto* right = abs_ct->add_exprs();
right->add_vars(new_var_index);
right->add_coeffs(-1);
updated_model_proto.mutable_objective()->add_vars(abs_var_index);
updated_model_proto.mutable_objective()->add_coeffs(1);

View File

@@ -415,29 +415,55 @@ namespace Google.OrTools.Sat
return ct;
}
public Constraint AddMinEquality(IntVar target, IEnumerable<IntVar> vars)
public Constraint AddMinEquality(LinearExpr target, IEnumerable<IntVar> vars)
{
Constraint ct = new Constraint(model_);
IntegerArgumentProto args = new IntegerArgumentProto();
LinearArgumentProto args = new LinearArgumentProto();
foreach (IntVar var in vars)
{
args.Vars.Add(var.Index);
args.Exprs.Add(GetLinearExpressionProto(var, /*negate=*/true));
}
args.Target = target.Index;
ct.Proto.IntMin = args;
args.Target = GetLinearExpressionProto(target, /*negate=*/true);
ct.Proto.LinMax = args;
return ct;
}
public Constraint AddLinMinEquality(LinearExpr target, IEnumerable<LinearExpr> exprs)
{
Constraint ct = new Constraint(model_);
LinearArgumentProto args = new LinearArgumentProto();
foreach (LinearExpr expr in exprs)
{
args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/true));
}
args.Target = GetLinearExpressionProto(target, /*negate=*/true);
ct.Proto.LinMax = args;
return ct;
}
public Constraint AddMaxEquality(IntVar target, IEnumerable<IntVar> vars)
{
Constraint ct = new Constraint(model_);
IntegerArgumentProto args = new IntegerArgumentProto();
LinearArgumentProto args = new LinearArgumentProto();
foreach (IntVar var in vars)
{
args.Vars.Add(var.Index);
args.Exprs.Add(GetLinearExpressionProto(var, /*negate=*/false));
}
args.Target = target.Index;
ct.Proto.IntMax = args;
args.Target = GetLinearExpressionProto(target, /*negate=*/false);
ct.Proto.LinMax = args;
return ct;
}
public Constraint AddLinMaxEquality(LinearExpr target, IEnumerable<LinearExpr> exprs)
{
Constraint ct = new Constraint(model_);
LinearArgumentProto args = new LinearArgumentProto();
foreach (LinearExpr expr in exprs)
{
args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/false));
}
args.Target = GetLinearExpressionProto(target, /*negate=*/false);
ct.Proto.LinMax = args;
return ct;
}
@@ -452,14 +478,14 @@ namespace Google.OrTools.Sat
return ct;
}
public Constraint AddAbsEquality(IntVar target, IntVar var)
public Constraint AddAbsEquality(LinearExpr target, LinearExpr expr)
{
Constraint ct = new Constraint(model_);
IntegerArgumentProto args = new IntegerArgumentProto();
args.Vars.Add(var.Index);
args.Vars.Add(-var.Index - 1);
args.Target = target.Index;
ct.Proto.IntMax = args;
LinearArgumentProto args = new LinearArgumentProto();
args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/false));
args.Exprs.Add(GetLinearExpressionProto(expr, /*negate=*/true));
args.Target = GetLinearExpressionProto(target, /*negate=*/false);
ct.Proto.LinMax = args;
return ct;
}
@@ -501,9 +527,9 @@ namespace Google.OrTools.Sat
LinearExpr endExpr = GetLinearExpr(end);
Add(startExpr + durationExpr == endExpr);
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr);
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr, /*negate=*/false);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr, /*negate=*/false);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr, /*negate=*/false);
return new IntervalVar(model_, startProto, durationProto, endProto, name);
}
@@ -513,9 +539,9 @@ namespace Google.OrTools.Sat
LinearExpr durationExpr = GetLinearExpr(duration);
LinearExpr endExpr = LinearExpr.Sum(new LinearExpr[] { startExpr, durationExpr });
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr);
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr, /*negate=*/false);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr, /*negate=*/false);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr, /*negate=*/false);
return new IntervalVar(model_, startProto, durationProto, endProto, name);
}
@@ -526,9 +552,9 @@ namespace Google.OrTools.Sat
LinearExpr endExpr = GetLinearExpr(end);
Add(startExpr + durationExpr == endExpr).OnlyEnforceIf(is_present);
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr);
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr, /*negate=*/false);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr, /*negate=*/false);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr, /*negate=*/false);
return new IntervalVar(model_, startProto, durationProto, endProto, is_present.GetIndex(), name);
}
@@ -538,9 +564,9 @@ namespace Google.OrTools.Sat
LinearExpr durationExpr = GetLinearExpr(duration);
LinearExpr endExpr = LinearExpr.Sum(new LinearExpr[] { startExpr, durationExpr });
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr);
LinearExpressionProto startProto = GetLinearExpressionProto(startExpr, /*negate=*/false);
LinearExpressionProto durationProto = GetLinearExpressionProto(durationExpr, /*negate=*/false);
LinearExpressionProto endProto = GetLinearExpressionProto(endExpr, /*negate=*/false);
return new IntervalVar(model_, startProto, durationProto, endProto, is_present.GetIndex(), name);
}
@@ -791,17 +817,18 @@ namespace Google.OrTools.Sat
throw new ArgumentException("Cannot convert argument to LinearExpr");
}
private LinearExpressionProto GetLinearExpressionProto(LinearExpr expr)
private LinearExpressionProto GetLinearExpressionProto(LinearExpr expr, bool negate)
{
Dictionary<IntVar, long> dict = new Dictionary<IntVar, long>();
long constant = LinearExpr.GetVarValueMap(expr, 1L, dict);
long mult = negate ? -1 : 1;
LinearExpressionProto linear = new LinearExpressionProto();
foreach (KeyValuePair<IntVar, long> term in dict)
{
linear.Vars.Add(term.Key.Index);
linear.Coeffs.Add(term.Value);
linear.Coeffs.Add(term.Value * mult);
}
linear.Offset = constant;
linear.Offset = constant * mult;
return linear;
}

View File

@@ -793,8 +793,6 @@ bool PresolveContext::StoreAffineRelation(int ref_x, int ref_y, int64_t coeff,
// These maps should only contains representative, so only need to remap
// either x or y.
const int rep = GetAffineRelation(x).representative;
if (x != rep) encoding_remap_queue_.push_back(x);
if (y != rep) encoding_remap_queue_.push_back(y);
// The domain didn't change, but this notification allows to re-process any
// constraint containing these variables. Note that we do not need to
@@ -964,89 +962,6 @@ void PresolveContext::InitializeNewDomains() {
var_to_lb_only_constraints.resize(domains.size());
}
bool PresolveContext::RemapEncodingMaps() {
// TODO(user): for now, while the code works most of the time, it triggers
// weird side effect that causes some issues in some LNS presolve...
// We should continue the investigation before activating it.
//
// Note also that because all our encoding constraints are present in the
// model, they will be remapped, and the new mapping re-added again. So while
// the current code might not be efficient, it should eventually reach the
// same effect.
encoding_remap_queue_.clear();
// Note that InsertVarValueEncodingInternal() will potentially add new entry
// to the encoding_ map, but for a different variables. So this code relies on
// the fact that the var_map shouldn't change content nor address of the
// "var_map" below while we iterate on them.
for (const int var : encoding_remap_queue_) {
CHECK(RefIsPositive(var));
const AffineRelation::Relation r = GetAffineRelation(var);
if (r.representative == var) return true;
int num_remapping = 0;
// Encoding.
{
const absl::flat_hash_map<int64_t, SavedLiteral>& var_map =
encoding_[var];
for (const auto& entry : var_map) {
const int lit = entry.second.Get(this);
if (removed_variables_.contains(PositiveRef(lit))) continue;
if ((entry.first - r.offset) % r.coeff != 0) continue;
const int64_t rep_value = (entry.first - r.offset) / r.coeff;
++num_remapping;
InsertVarValueEncodingInternal(lit, r.representative, rep_value,
/*add_constraints=*/false);
if (is_unsat_) return false;
}
encoding_.erase(var);
}
// Eq half encoding.
{
const absl::flat_hash_map<int64_t, absl::flat_hash_set<int>>& var_map =
eq_half_encoding_[var];
for (const auto& entry : var_map) {
if ((entry.first - r.offset) % r.coeff != 0) continue;
const int64_t rep_value = (entry.first - r.offset) / r.coeff;
for (int literal : entry.second) {
++num_remapping;
InsertHalfVarValueEncoding(GetLiteralRepresentative(literal),
r.representative, rep_value,
/*imply_eq=*/true);
if (is_unsat_) return false;
}
}
eq_half_encoding_.erase(var);
}
// Neq half encoding.
{
const absl::flat_hash_map<int64_t, absl::flat_hash_set<int>>& var_map =
neq_half_encoding_[var];
for (const auto& entry : var_map) {
if ((entry.first - r.offset) % r.coeff != 0) continue;
const int64_t rep_value = (entry.first - r.offset) / r.coeff;
for (int literal : entry.second) {
++num_remapping;
InsertHalfVarValueEncoding(GetLiteralRepresentative(literal),
r.representative, rep_value,
/*imply_eq=*/false);
if (is_unsat_) return false;
}
}
neq_half_encoding_.erase(var);
}
if (num_remapping > 0) {
VLOG(1) << "Remapped " << num_remapping << " encodings due to " << var
<< " -> " << r.representative << ".";
}
}
encoding_remap_queue_.clear();
return !is_unsat_;
}
void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) {
CHECK(RefIsPositive(var));
CHECK_EQ(DomainOf(var).Size(), 2);
@@ -1230,7 +1145,6 @@ bool PresolveContext::CanonicalizeEncoding(int* ref, int64_t* value) {
bool PresolveContext::InsertVarValueEncoding(int literal, int ref,
int64_t value) {
if (!RemapEncodingMaps()) return false;
if (!CanonicalizeEncoding(&ref, &value)) {
return SetLiteralToFalse(literal);
}
@@ -1241,7 +1155,6 @@ bool PresolveContext::InsertVarValueEncoding(int literal, int ref,
bool PresolveContext::StoreLiteralImpliesVarEqValue(int literal, int var,
int64_t value) {
if (!RemapEncodingMaps()) return false;
if (!CanonicalizeEncoding(&var, &value)) return false;
literal = GetLiteralRepresentative(literal);
return InsertHalfVarValueEncoding(literal, var, value, /*imply_eq=*/true);
@@ -1249,7 +1162,6 @@ bool PresolveContext::StoreLiteralImpliesVarEqValue(int literal, int var,
bool PresolveContext::StoreLiteralImpliesVarNEqValue(int literal, int var,
int64_t value) {
if (!RemapEncodingMaps()) return false;
if (!CanonicalizeEncoding(&var, &value)) return false;
literal = GetLiteralRepresentative(literal);
return InsertHalfVarValueEncoding(literal, var, value, /*imply_eq=*/false);
@@ -1258,7 +1170,6 @@ bool PresolveContext::StoreLiteralImpliesVarNEqValue(int literal, int var,
bool PresolveContext::HasVarValueEncoding(int ref, int64_t value,
int* literal) {
CHECK(!VariableWasRemoved(ref));
if (!RemapEncodingMaps()) return false;
if (!CanonicalizeEncoding(&ref, &value)) return false;
const absl::flat_hash_map<int64_t, SavedLiteral>& var_map = encoding_[ref];
const auto it = var_map.find(value);
@@ -1272,13 +1183,7 @@ bool PresolveContext::HasVarValueEncoding(int ref, int64_t value,
}
int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) {
// TODO(user): Remove this precondition. For now it is needed because
// we might remove encoding literal without updating the encoding map.
// This is related to RemapEncodingMaps() which is currently disabled.
CHECK(!ModelIsExpanded());
CHECK(!VariableWasRemoved(ref));
if (!RemapEncodingMaps()) return GetOrCreateConstantVar(0);
if (!CanonicalizeEncoding(&ref, &value)) return GetOrCreateConstantVar(0);
// Positive after CanonicalizeEncoding().
@@ -1293,7 +1198,14 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) {
absl::flat_hash_map<int64_t, SavedLiteral>& var_map = encoding_[var];
auto it = var_map.find(value);
if (it != var_map.end()) {
return it->second.Get(this);
const int lit = it->second.Get(this);
if (VariableWasRemoved(lit)) {
// If the variable was already removed, for now we create a new one.
// This should be rare hopefully.
var_map.erase(value);
} else {
return lit;
}
}
// Special case for fixed domains.
@@ -1311,11 +1223,17 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) {
const int64_t other_value = value == var_min ? var_max : var_min;
auto other_it = var_map.find(other_value);
if (other_it != var_map.end()) {
// Update the encoding map. The domain could have been reduced to size
// two after the creation of the first literal.
const int literal = NegatedRef(other_it->second.Get(this));
var_map[value] = SavedLiteral(literal);
return literal;
if (VariableWasRemoved(literal)) {
// If the variable was already removed, for now we create a new one.
// This should be rare hopefully.
var_map.erase(other_value);
} else {
// Update the encoding map. The domain could have been reduced to size
// two after the creation of the first literal.
var_map[value] = SavedLiteral(literal);
return literal;
}
}
if (var_min == 0 && var_max == 1) {

View File

@@ -478,10 +478,6 @@ class PresolveContext {
// class of size at least 2.
bool VariableIsNotRepresentativeOfEquivalenceClass(int var) const;
// Process encoding_remap_queue_ and updates the encoding maps. This could
// lead to UNSAT being detected, in which case it will return false.
bool RemapEncodingMaps();
// Makes sure we only insert encoding about the current representative.
//
// Returns false if ref cannot take the given value (it might not have been
@@ -547,10 +543,6 @@ class PresolveContext {
// equivalence relation. See ExploitFixedDomain().
absl::flat_hash_map<int64_t, SavedVariable> constant_to_ref_;
// When a "representative" gets a new representative, it should be enqueued
// here so that we can lazily update the *encoding_ maps below.
std::deque<int> encoding_remap_queue_;
// Contains variables with some encoded value: encoding_[i][v] points
// to the literal attached to the value v of the variable i.
absl::flat_hash_map<int, absl::flat_hash_map<int64_t, SavedLiteral>>

View File

@@ -1361,22 +1361,24 @@ class CpModel(object):
[self.GetOrMakeBooleanIndex(x) for x in literals])
return ct
def AddMinEquality(self, target, variables):
def AddMinEquality(self, target, exprs):
"""Adds `target == Min(variables)`."""
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
model_ct.int_min.vars.extend(
[self.GetOrMakeIndex(x) for x in variables])
model_ct.int_min.target = self.GetOrMakeIndex(target)
model_ct.lin_max.exprs.extend(
[self.ParseLinearExpression(x, True) for x in exprs])
model_ct.lin_max.target.CopyFrom(
self.ParseLinearExpression(target, True))
return ct
def AddMaxEquality(self, target, variables):
def AddMaxEquality(self, target, exprs):
"""Adds `target == Max(variables)`."""
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
model_ct.int_max.vars.extend(
[self.GetOrMakeIndex(x) for x in variables])
model_ct.int_max.target = self.GetOrMakeIndex(target)
model_ct.lin_max.exprs.extend(
[self.ParseLinearExpression(x, False) for x in exprs])
model_ct.lin_max.target.CopyFrom(
self.ParseLinearExpression(target, False))
return ct
def AddDivisionEquality(self, target, num, denom):
@@ -1389,13 +1391,14 @@ class CpModel(object):
model_ct.int_div.target = self.GetOrMakeIndex(target)
return ct
def AddAbsEquality(self, target, var):
def AddAbsEquality(self, target, expr):
"""Adds `target == Abs(var)`."""
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
index = self.GetOrMakeIndex(var)
model_ct.int_max.vars.extend([index, -index - 1])
model_ct.int_max.target = self.GetOrMakeIndex(target)
model_ct.lin_max.exprs.append(self.ParseLinearExpression(expr, False))
model_ct.lin_max.exprs.append(self.ParseLinearExpression(expr, True))
model_ct.lin_max.target.CopyFrom(
self.ParseLinearExpression(target, False))
return ct
def AddModuloEquality(self, target, var, mod):
@@ -1448,9 +1451,9 @@ class CpModel(object):
self.Add(start + size == end)
start_view = self.ParseLinearExpression(start)
size_view = self.ParseLinearExpression(size)
end_view = self.ParseLinearExpression(end)
start_view = self.ParseLinearExpression(start, False)
size_view = self.ParseLinearExpression(size, False)
end_view = self.ParseLinearExpression(end, False)
if len(start_view.vars) > 1:
raise TypeError(
'cp_model.NewIntervalVar: start must be affine or constant.')
@@ -1479,9 +1482,9 @@ class CpModel(object):
An `IntervalVar` object.
"""
cp_model_helper.AssertIsInt64(size)
start_view = self.ParseLinearExpression(start)
size_view = self.ParseLinearExpression(size)
end_view = self.ParseLinearExpression(start + size)
start_view = self.ParseLinearExpression(start, False)
size_view = self.ParseLinearExpression(size, False)
end_view = self.ParseLinearExpression(start + size, False)
if len(start_view.vars) > 1:
raise TypeError(
'cp_model.NewIntervalVar: start must be affine or constant.')
@@ -1517,9 +1520,9 @@ class CpModel(object):
# Creates the IntervalConstraintProto object.
is_present_index = self.GetOrMakeBooleanIndex(is_present)
start_view = self.ParseLinearExpression(start)
size_view = self.ParseLinearExpression(size)
end_view = self.ParseLinearExpression(end)
start_view = self.ParseLinearExpression(start, False)
size_view = self.ParseLinearExpression(size, False)
end_view = self.ParseLinearExpression(end, False)
if len(start_view.vars) > 1:
raise TypeError(
'cp_model.NewIntervalVar: start must be affine or constant.')
@@ -1550,9 +1553,9 @@ class CpModel(object):
An `IntervalVar` object.
"""
cp_model_helper.AssertIsInt64(size)
start_view = self.ParseLinearExpression(start)
size_view = self.ParseLinearExpression(size)
end_view = self.ParseLinearExpression(start + size)
start_view = self.ParseLinearExpression(start, False)
size_view = self.ParseLinearExpression(size, False)
end_view = self.ParseLinearExpression(start + size, False)
if len(start_view.vars) > 1:
raise TypeError(
'cp_model.NewIntervalVar: start must be affine or constant.')
@@ -1657,7 +1660,8 @@ class CpModel(object):
model_ct.cumulative.demands.extend(
[self.GetOrMakeIndex(x) for x in demands])
for e in energies:
model_ct.cumulative.energies.append(self.ParseLinearExpression(e))
model_ct.cumulative.energies.append(
self.ParseLinearExpression(e, False))
model_ct.cumulative.capacity = self.GetOrMakeIndex(capacity)
return ct
@@ -1766,26 +1770,27 @@ class CpModel(object):
else:
return self.__model.variables[-var_index - 1]
def ParseLinearExpression(self, linear_expr):
def ParseLinearExpression(self, linear_expr, negate):
"""Returns a LinearExpressionProto built from a LinearExpr instance."""
result = cp_model_pb2.LinearExpressionProto()
mult = -1 if negate else 1
if isinstance(linear_expr, numbers.Integral):
result.offset = linear_expr
result.offset = linear_expr * mult
return result
if isinstance(linear_expr, IntVar):
result.vars.append(self.GetOrMakeIndex(linear_expr))
result.coeffs.append(1)
result.coeffs.append(mult)
return result
coeffs_map, constant = linear_expr.GetVarValueMap()
result.offset = constant
result.offset = constant * mult
for t in coeffs_map.items():
if not isinstance(t[0], IntVar):
raise TypeError('Wrong argument' + str(t))
cp_model_helper.AssertIsInt64(t[1])
result.vars.append(t[0].Index())
result.coeffs.append(t[1])
result.coeffs.append(t[1] * mult)
return result
def _SetObjective(self, obj, minimize):