polish code
This commit is contained in:
@@ -374,7 +374,7 @@ EnforcementStatus EnforcementPropagator::DebugStatus(EnforcementId id) {
|
||||
}
|
||||
|
||||
BooleanXorPropagator::BooleanXorPropagator(
|
||||
const std::vector<Literal>& enforcement_literals,
|
||||
absl::Span<const Literal> enforcement_literals,
|
||||
const std::vector<Literal>& literals, bool value, Model* model)
|
||||
: literals_(literals),
|
||||
value_(value),
|
||||
|
||||
@@ -159,7 +159,7 @@ class EnforcementPropagator : public SatPropagator {
|
||||
// faster.
|
||||
class BooleanXorPropagator : public PropagatorInterface {
|
||||
public:
|
||||
BooleanXorPropagator(const std::vector<Literal>& enforcement_literals,
|
||||
BooleanXorPropagator(absl::Span<const Literal> enforcement_literals,
|
||||
const std::vector<Literal>& literals, bool value,
|
||||
Model* model);
|
||||
|
||||
|
||||
@@ -184,7 +184,7 @@ _IndexOrSeries = Union[pd.Index, pd.Series]
|
||||
|
||||
# Helper functions.
|
||||
def snake_case_to_camel_case(name: str) -> str:
|
||||
"""Converts a snake_case name to camelCase."""
|
||||
"""Converts a snake_case name to CamelCase."""
|
||||
words = name.split("_")
|
||||
return (
|
||||
"".join(word.capitalize() for word in words)
|
||||
@@ -900,7 +900,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_bool_or(self, *literals):
|
||||
"""Adds `Or(literals) == true`: sum(literals) >= 1."""
|
||||
return self._add_bool_argument_constraint("or", *literals)
|
||||
return self._add_bool_argument_constraint(
|
||||
cmh.BoolArgumentConstraint.bool_or, *literals
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_at_least_one(self, literals: Iterable[LiteralT]) -> Constraint: ...
|
||||
@@ -910,7 +912,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_at_least_one(self, *literals):
|
||||
"""Same as `add_bool_or`: `sum(literals) >= 1`."""
|
||||
return self._add_bool_argument_constraint("or", *literals)
|
||||
return self._add_bool_argument_constraint(
|
||||
cmh.BoolArgumentConstraint.bool_or, *literals
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_at_most_one(self, literals: Iterable[LiteralT]) -> Constraint: ...
|
||||
@@ -920,7 +924,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_at_most_one(self, *literals) -> Constraint:
|
||||
"""Adds `AtMostOne(literals)`: `sum(literals) <= 1`."""
|
||||
return self._add_bool_argument_constraint("at_most_one", *literals)
|
||||
return self._add_bool_argument_constraint(
|
||||
cmh.BoolArgumentConstraint.at_most_one, *literals
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_exactly_one(self, literals: Iterable[LiteralT]) -> Constraint: ...
|
||||
@@ -930,7 +936,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_exactly_one(self, *literals):
|
||||
"""Adds `ExactlyOne(literals)`: `sum(literals) == 1`."""
|
||||
return self._add_bool_argument_constraint("exactly_one", *literals)
|
||||
return self._add_bool_argument_constraint(
|
||||
cmh.BoolArgumentConstraint.exactly_one, *literals
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_bool_and(self, literals: Iterable[LiteralT]) -> Constraint: ...
|
||||
@@ -940,7 +948,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_bool_and(self, *literals):
|
||||
"""Adds `And(literals) == true`."""
|
||||
return self._add_bool_argument_constraint("and", *literals)
|
||||
return self._add_bool_argument_constraint(
|
||||
cmh.BoolArgumentConstraint.bool_and, *literals
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_bool_xor(self, literals: Iterable[LiteralT]) -> Constraint: ...
|
||||
@@ -960,7 +970,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
Returns:
|
||||
An `Constraint` object.
|
||||
"""
|
||||
return self._add_bool_argument_constraint("xor", *literals)
|
||||
return self._add_bool_argument_constraint(
|
||||
cmh.BoolArgumentConstraint.bool_xor, *literals
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_min_equality(
|
||||
@@ -974,7 +986,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_min_equality(self, target, *expressions) -> Constraint:
|
||||
"""Adds `target == Min(expressions)`."""
|
||||
return self._add_linear_argument_constraint("min", target, *expressions)
|
||||
return self._add_linear_argument_constraint(
|
||||
cmh.LinearArgumentConstraint.min, target, *expressions
|
||||
)
|
||||
|
||||
@overload
|
||||
def add_max_equality(
|
||||
@@ -988,17 +1002,23 @@ class CpModel(cmh.CpBaseModel):
|
||||
|
||||
def add_max_equality(self, target, *expressions) -> Constraint:
|
||||
"""Adds `target == Max(expressions)`."""
|
||||
return self._add_linear_argument_constraint("max", target, *expressions)
|
||||
return self._add_linear_argument_constraint(
|
||||
cmh.LinearArgumentConstraint.max, target, *expressions
|
||||
)
|
||||
|
||||
def add_division_equality(
|
||||
self, target: LinearExprT, num: LinearExprT, denom: LinearExprT
|
||||
) -> Constraint:
|
||||
"""Adds `target == num // denom` (integer division rounded towards 0)."""
|
||||
return self._add_linear_argument_constraint("div", target, [num, denom])
|
||||
return self._add_linear_argument_constraint(
|
||||
cmh.LinearArgumentConstraint.div, target, [num, denom]
|
||||
)
|
||||
|
||||
def add_abs_equality(self, target: LinearExprT, expr: LinearExprT) -> Constraint:
|
||||
"""Adds `target == Abs(expr)`."""
|
||||
return self._add_linear_argument_constraint("max", target, [expr, -expr])
|
||||
return self._add_linear_argument_constraint(
|
||||
cmh.LinearArgumentConstraint.max, target, [expr, -expr]
|
||||
)
|
||||
|
||||
def add_modulo_equality(
|
||||
self, target: LinearExprT, expr: LinearExprT, mod: LinearExprT
|
||||
@@ -1022,7 +1042,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
Returns:
|
||||
A `Constraint` object.
|
||||
"""
|
||||
return self._add_linear_argument_constraint("mod", target, [expr, mod])
|
||||
return self._add_linear_argument_constraint(
|
||||
cmh.LinearArgumentConstraint.mod, target, [expr, mod]
|
||||
)
|
||||
|
||||
def add_multiplication_equality(
|
||||
self,
|
||||
@@ -1030,7 +1052,9 @@ class CpModel(cmh.CpBaseModel):
|
||||
*expressions: Union[Iterable[LinearExprT], LinearExprT],
|
||||
) -> Constraint:
|
||||
"""Adds `target == expressions[0] * .. * expressions[n]`."""
|
||||
return self._add_linear_argument_constraint("prod", target, *expressions)
|
||||
return self._add_linear_argument_constraint(
|
||||
cmh.LinearArgumentConstraint.prod, target, *expressions
|
||||
)
|
||||
|
||||
# Scheduling support
|
||||
|
||||
|
||||
@@ -462,6 +462,22 @@ void LinearExprToProto(const py::handle& arg, int64_t multiplier,
|
||||
class Constraint;
|
||||
class IntervalVar;
|
||||
|
||||
enum class BoolArgumentConstraint {
|
||||
kAtMostOne,
|
||||
kBoolAnd,
|
||||
kBoolOr,
|
||||
kBoolXor,
|
||||
kExactlyOne,
|
||||
};
|
||||
|
||||
enum class LinearArgumentConstraint {
|
||||
kDiv,
|
||||
kMax,
|
||||
kMin,
|
||||
kMod,
|
||||
kProd,
|
||||
};
|
||||
|
||||
class CpBaseModel : public std::enable_shared_from_this<CpBaseModel> {
|
||||
public:
|
||||
CpBaseModel()
|
||||
@@ -573,7 +589,7 @@ class CpBaseModel : public std::enable_shared_from_this<CpBaseModel> {
|
||||
const std::vector<std::vector<int64_t>>& transition_triples);
|
||||
|
||||
std::shared_ptr<Constraint> AddBoolArgumentConstraintInternal(
|
||||
const std::string& name, py::args literals);
|
||||
BoolArgumentConstraint type, py::args literals);
|
||||
|
||||
std::shared_ptr<Constraint> AddBoundedLinearExpressionInternal(
|
||||
BoundedLinearExpression* ble);
|
||||
@@ -586,7 +602,7 @@ class CpBaseModel : public std::enable_shared_from_this<CpBaseModel> {
|
||||
py::sequence inverse);
|
||||
|
||||
std::shared_ptr<Constraint> AddLinearArgumentConstraintInternal(
|
||||
const std::string& name, const py::handle& target, py::args exprs);
|
||||
LinearArgumentConstraint type, const py::handle& target, py::args exprs);
|
||||
|
||||
std::shared_ptr<Constraint> AddReservoirInternal(py::sequence times,
|
||||
py::sequence level_changes,
|
||||
@@ -702,23 +718,29 @@ std::shared_ptr<Constraint> CpBaseModel::AddAutomatonInternal(
|
||||
}
|
||||
|
||||
std::shared_ptr<Constraint> CpBaseModel::AddBoolArgumentConstraintInternal(
|
||||
const std::string& name, py::args literals) {
|
||||
BoolArgumentConstraint type, py::args literals) {
|
||||
const int ct_index = model_proto_->constraints_size();
|
||||
ConstraintProto* ct = model_proto_->add_constraints();
|
||||
BoolArgumentProto* proto = nullptr;
|
||||
if (name == "or") {
|
||||
proto = ct->mutable_bool_or();
|
||||
} else if (name == "and") {
|
||||
proto = ct->mutable_bool_and();
|
||||
} else if (name == "xor") {
|
||||
proto = ct->mutable_bool_xor();
|
||||
} else if (name == "at_most_one") {
|
||||
proto = ct->mutable_at_most_one();
|
||||
} else if (name == "exactly_one") {
|
||||
proto = ct->mutable_exactly_one();
|
||||
} else {
|
||||
ThrowError(PyExc_ValueError,
|
||||
absl::StrCat("Unknown boolean argument constraint: ", name));
|
||||
switch (type) {
|
||||
case BoolArgumentConstraint::kAtMostOne:
|
||||
proto = ct->mutable_at_most_one();
|
||||
break;
|
||||
case BoolArgumentConstraint::kBoolAnd:
|
||||
proto = ct->mutable_bool_and();
|
||||
break;
|
||||
case BoolArgumentConstraint::kBoolOr:
|
||||
proto = ct->mutable_bool_or();
|
||||
break;
|
||||
case BoolArgumentConstraint::kBoolXor:
|
||||
proto = ct->mutable_bool_xor();
|
||||
break;
|
||||
case BoolArgumentConstraint::kExactlyOne:
|
||||
proto = ct->mutable_exactly_one();
|
||||
break;
|
||||
default:
|
||||
ThrowError(PyExc_ValueError,
|
||||
absl::StrCat("Unknown boolean argument constraint: ", type));
|
||||
}
|
||||
if (literals.size() == 1 && py::isinstance<py::iterable>(literals[0])) {
|
||||
for (const auto& literal : literals[0]) {
|
||||
@@ -781,25 +803,31 @@ std::shared_ptr<Constraint> CpBaseModel::AddInverseInternal(
|
||||
}
|
||||
|
||||
std::shared_ptr<Constraint> CpBaseModel::AddLinearArgumentConstraintInternal(
|
||||
const std::string& name, const py::handle& target, py::args exprs) {
|
||||
LinearArgumentConstraint type, const py::handle& target, py::args exprs) {
|
||||
const int ct_index = model_proto_->constraints_size();
|
||||
ConstraintProto* ct = model_proto_->add_constraints();
|
||||
LinearArgumentProto* proto;
|
||||
int64_t multiplier = 1;
|
||||
if (name == "min") {
|
||||
proto = ct->mutable_lin_max();
|
||||
multiplier = -1;
|
||||
} else if (name == "max") {
|
||||
proto = ct->mutable_lin_max();
|
||||
} else if (name == "prod") {
|
||||
proto = ct->mutable_int_prod();
|
||||
} else if (name == "div") {
|
||||
proto = ct->mutable_int_div();
|
||||
} else if (name == "mod") {
|
||||
proto = ct->mutable_int_mod();
|
||||
} else {
|
||||
ThrowError(PyExc_ValueError,
|
||||
absl::StrCat("Unknown integer argument constraint: ", name));
|
||||
switch (type) {
|
||||
case LinearArgumentConstraint::kDiv:
|
||||
proto = ct->mutable_int_div();
|
||||
break;
|
||||
case LinearArgumentConstraint::kMax:
|
||||
proto = ct->mutable_lin_max();
|
||||
break;
|
||||
case LinearArgumentConstraint::kMin:
|
||||
proto = ct->mutable_lin_max();
|
||||
multiplier = -1;
|
||||
break;
|
||||
case LinearArgumentConstraint::kMod:
|
||||
proto = ct->mutable_int_mod();
|
||||
break;
|
||||
case LinearArgumentConstraint::kProd:
|
||||
proto = ct->mutable_int_prod();
|
||||
break;
|
||||
default:
|
||||
ThrowError(PyExc_ValueError,
|
||||
absl::StrCat("Unknown integer argument constraint: ", type));
|
||||
}
|
||||
|
||||
LinearExprToProto(target, multiplier, proto->mutable_target());
|
||||
@@ -1871,6 +1899,22 @@ PYBIND11_MODULE(cp_model_helper, m) {
|
||||
return false;
|
||||
});
|
||||
|
||||
py::enum_<BoolArgumentConstraint>(m, "BoolArgumentConstraint")
|
||||
.value("at_most_one", BoolArgumentConstraint::kAtMostOne)
|
||||
.value("bool_and", BoolArgumentConstraint::kBoolAnd)
|
||||
.value("bool_or", BoolArgumentConstraint::kBoolOr)
|
||||
.value("bool_xor", BoolArgumentConstraint::kBoolXor)
|
||||
.value("exactly_one", BoolArgumentConstraint::kExactlyOne)
|
||||
.export_values();
|
||||
|
||||
py::enum_<LinearArgumentConstraint>(m, "LinearArgumentConstraint")
|
||||
.value("div", LinearArgumentConstraint::kDiv)
|
||||
.value("max", LinearArgumentConstraint::kMax)
|
||||
.value("min", LinearArgumentConstraint::kMin)
|
||||
.value("mod", LinearArgumentConstraint::kMod)
|
||||
.value("prod", LinearArgumentConstraint::kProd)
|
||||
.export_values();
|
||||
|
||||
py::class_<CpBaseModel, std::shared_ptr<CpBaseModel>>(
|
||||
m, "CpBaseModel", "Base class for the CP model.")
|
||||
.def(py::init<>())
|
||||
|
||||
Reference in New Issue
Block a user