polish code

This commit is contained in:
Laurent Perron
2025-07-27 15:29:20 -07:00
parent a0e25debc5
commit 374cc3e596
4 changed files with 114 additions and 46 deletions

View File

@@ -374,7 +374,7 @@ EnforcementStatus EnforcementPropagator::DebugStatus(EnforcementId id) {
}
BooleanXorPropagator::BooleanXorPropagator(
const std::vector<Literal>& enforcement_literals,
absl::Span<const Literal> enforcement_literals,
const std::vector<Literal>& literals, bool value, Model* model)
: literals_(literals),
value_(value),

View File

@@ -159,7 +159,7 @@ class EnforcementPropagator : public SatPropagator {
// faster.
class BooleanXorPropagator : public PropagatorInterface {
public:
BooleanXorPropagator(const std::vector<Literal>& enforcement_literals,
BooleanXorPropagator(absl::Span<const Literal> enforcement_literals,
const std::vector<Literal>& literals, bool value,
Model* model);

View File

@@ -184,7 +184,7 @@ _IndexOrSeries = Union[pd.Index, pd.Series]
# Helper functions.
def snake_case_to_camel_case(name: str) -> str:
"""Converts a snake_case name to camelCase."""
"""Converts a snake_case name to CamelCase."""
words = name.split("_")
return (
"".join(word.capitalize() for word in words)
@@ -900,7 +900,9 @@ class CpModel(cmh.CpBaseModel):
def add_bool_or(self, *literals):
"""Adds `Or(literals) == true`: sum(literals) >= 1."""
return self._add_bool_argument_constraint("or", *literals)
return self._add_bool_argument_constraint(
cmh.BoolArgumentConstraint.bool_or, *literals
)
@overload
def add_at_least_one(self, literals: Iterable[LiteralT]) -> Constraint: ...
@@ -910,7 +912,9 @@ class CpModel(cmh.CpBaseModel):
def add_at_least_one(self, *literals):
"""Same as `add_bool_or`: `sum(literals) >= 1`."""
return self._add_bool_argument_constraint("or", *literals)
return self._add_bool_argument_constraint(
cmh.BoolArgumentConstraint.bool_or, *literals
)
@overload
def add_at_most_one(self, literals: Iterable[LiteralT]) -> Constraint: ...
@@ -920,7 +924,9 @@ class CpModel(cmh.CpBaseModel):
def add_at_most_one(self, *literals) -> Constraint:
"""Adds `AtMostOne(literals)`: `sum(literals) <= 1`."""
return self._add_bool_argument_constraint("at_most_one", *literals)
return self._add_bool_argument_constraint(
cmh.BoolArgumentConstraint.at_most_one, *literals
)
@overload
def add_exactly_one(self, literals: Iterable[LiteralT]) -> Constraint: ...
@@ -930,7 +936,9 @@ class CpModel(cmh.CpBaseModel):
def add_exactly_one(self, *literals):
"""Adds `ExactlyOne(literals)`: `sum(literals) == 1`."""
return self._add_bool_argument_constraint("exactly_one", *literals)
return self._add_bool_argument_constraint(
cmh.BoolArgumentConstraint.exactly_one, *literals
)
@overload
def add_bool_and(self, literals: Iterable[LiteralT]) -> Constraint: ...
@@ -940,7 +948,9 @@ class CpModel(cmh.CpBaseModel):
def add_bool_and(self, *literals):
"""Adds `And(literals) == true`."""
return self._add_bool_argument_constraint("and", *literals)
return self._add_bool_argument_constraint(
cmh.BoolArgumentConstraint.bool_and, *literals
)
@overload
def add_bool_xor(self, literals: Iterable[LiteralT]) -> Constraint: ...
@@ -960,7 +970,9 @@ class CpModel(cmh.CpBaseModel):
Returns:
An `Constraint` object.
"""
return self._add_bool_argument_constraint("xor", *literals)
return self._add_bool_argument_constraint(
cmh.BoolArgumentConstraint.bool_xor, *literals
)
@overload
def add_min_equality(
@@ -974,7 +986,9 @@ class CpModel(cmh.CpBaseModel):
def add_min_equality(self, target, *expressions) -> Constraint:
"""Adds `target == Min(expressions)`."""
return self._add_linear_argument_constraint("min", target, *expressions)
return self._add_linear_argument_constraint(
cmh.LinearArgumentConstraint.min, target, *expressions
)
@overload
def add_max_equality(
@@ -988,17 +1002,23 @@ class CpModel(cmh.CpBaseModel):
def add_max_equality(self, target, *expressions) -> Constraint:
"""Adds `target == Max(expressions)`."""
return self._add_linear_argument_constraint("max", target, *expressions)
return self._add_linear_argument_constraint(
cmh.LinearArgumentConstraint.max, target, *expressions
)
def add_division_equality(
self, target: LinearExprT, num: LinearExprT, denom: LinearExprT
) -> Constraint:
"""Adds `target == num // denom` (integer division rounded towards 0)."""
return self._add_linear_argument_constraint("div", target, [num, denom])
return self._add_linear_argument_constraint(
cmh.LinearArgumentConstraint.div, target, [num, denom]
)
def add_abs_equality(self, target: LinearExprT, expr: LinearExprT) -> Constraint:
"""Adds `target == Abs(expr)`."""
return self._add_linear_argument_constraint("max", target, [expr, -expr])
return self._add_linear_argument_constraint(
cmh.LinearArgumentConstraint.max, target, [expr, -expr]
)
def add_modulo_equality(
self, target: LinearExprT, expr: LinearExprT, mod: LinearExprT
@@ -1022,7 +1042,9 @@ class CpModel(cmh.CpBaseModel):
Returns:
A `Constraint` object.
"""
return self._add_linear_argument_constraint("mod", target, [expr, mod])
return self._add_linear_argument_constraint(
cmh.LinearArgumentConstraint.mod, target, [expr, mod]
)
def add_multiplication_equality(
self,
@@ -1030,7 +1052,9 @@ class CpModel(cmh.CpBaseModel):
*expressions: Union[Iterable[LinearExprT], LinearExprT],
) -> Constraint:
"""Adds `target == expressions[0] * .. * expressions[n]`."""
return self._add_linear_argument_constraint("prod", target, *expressions)
return self._add_linear_argument_constraint(
cmh.LinearArgumentConstraint.prod, target, *expressions
)
# Scheduling support

View File

@@ -462,6 +462,22 @@ void LinearExprToProto(const py::handle& arg, int64_t multiplier,
class Constraint;
class IntervalVar;
enum class BoolArgumentConstraint {
kAtMostOne,
kBoolAnd,
kBoolOr,
kBoolXor,
kExactlyOne,
};
enum class LinearArgumentConstraint {
kDiv,
kMax,
kMin,
kMod,
kProd,
};
class CpBaseModel : public std::enable_shared_from_this<CpBaseModel> {
public:
CpBaseModel()
@@ -573,7 +589,7 @@ class CpBaseModel : public std::enable_shared_from_this<CpBaseModel> {
const std::vector<std::vector<int64_t>>& transition_triples);
std::shared_ptr<Constraint> AddBoolArgumentConstraintInternal(
const std::string& name, py::args literals);
BoolArgumentConstraint type, py::args literals);
std::shared_ptr<Constraint> AddBoundedLinearExpressionInternal(
BoundedLinearExpression* ble);
@@ -586,7 +602,7 @@ class CpBaseModel : public std::enable_shared_from_this<CpBaseModel> {
py::sequence inverse);
std::shared_ptr<Constraint> AddLinearArgumentConstraintInternal(
const std::string& name, const py::handle& target, py::args exprs);
LinearArgumentConstraint type, const py::handle& target, py::args exprs);
std::shared_ptr<Constraint> AddReservoirInternal(py::sequence times,
py::sequence level_changes,
@@ -702,23 +718,29 @@ std::shared_ptr<Constraint> CpBaseModel::AddAutomatonInternal(
}
std::shared_ptr<Constraint> CpBaseModel::AddBoolArgumentConstraintInternal(
const std::string& name, py::args literals) {
BoolArgumentConstraint type, py::args literals) {
const int ct_index = model_proto_->constraints_size();
ConstraintProto* ct = model_proto_->add_constraints();
BoolArgumentProto* proto = nullptr;
if (name == "or") {
proto = ct->mutable_bool_or();
} else if (name == "and") {
proto = ct->mutable_bool_and();
} else if (name == "xor") {
proto = ct->mutable_bool_xor();
} else if (name == "at_most_one") {
proto = ct->mutable_at_most_one();
} else if (name == "exactly_one") {
proto = ct->mutable_exactly_one();
} else {
ThrowError(PyExc_ValueError,
absl::StrCat("Unknown boolean argument constraint: ", name));
switch (type) {
case BoolArgumentConstraint::kAtMostOne:
proto = ct->mutable_at_most_one();
break;
case BoolArgumentConstraint::kBoolAnd:
proto = ct->mutable_bool_and();
break;
case BoolArgumentConstraint::kBoolOr:
proto = ct->mutable_bool_or();
break;
case BoolArgumentConstraint::kBoolXor:
proto = ct->mutable_bool_xor();
break;
case BoolArgumentConstraint::kExactlyOne:
proto = ct->mutable_exactly_one();
break;
default:
ThrowError(PyExc_ValueError,
absl::StrCat("Unknown boolean argument constraint: ", type));
}
if (literals.size() == 1 && py::isinstance<py::iterable>(literals[0])) {
for (const auto& literal : literals[0]) {
@@ -781,25 +803,31 @@ std::shared_ptr<Constraint> CpBaseModel::AddInverseInternal(
}
std::shared_ptr<Constraint> CpBaseModel::AddLinearArgumentConstraintInternal(
const std::string& name, const py::handle& target, py::args exprs) {
LinearArgumentConstraint type, const py::handle& target, py::args exprs) {
const int ct_index = model_proto_->constraints_size();
ConstraintProto* ct = model_proto_->add_constraints();
LinearArgumentProto* proto;
int64_t multiplier = 1;
if (name == "min") {
proto = ct->mutable_lin_max();
multiplier = -1;
} else if (name == "max") {
proto = ct->mutable_lin_max();
} else if (name == "prod") {
proto = ct->mutable_int_prod();
} else if (name == "div") {
proto = ct->mutable_int_div();
} else if (name == "mod") {
proto = ct->mutable_int_mod();
} else {
ThrowError(PyExc_ValueError,
absl::StrCat("Unknown integer argument constraint: ", name));
switch (type) {
case LinearArgumentConstraint::kDiv:
proto = ct->mutable_int_div();
break;
case LinearArgumentConstraint::kMax:
proto = ct->mutable_lin_max();
break;
case LinearArgumentConstraint::kMin:
proto = ct->mutable_lin_max();
multiplier = -1;
break;
case LinearArgumentConstraint::kMod:
proto = ct->mutable_int_mod();
break;
case LinearArgumentConstraint::kProd:
proto = ct->mutable_int_prod();
break;
default:
ThrowError(PyExc_ValueError,
absl::StrCat("Unknown integer argument constraint: ", type));
}
LinearExprToProto(target, multiplier, proto->mutable_target());
@@ -1871,6 +1899,22 @@ PYBIND11_MODULE(cp_model_helper, m) {
return false;
});
py::enum_<BoolArgumentConstraint>(m, "BoolArgumentConstraint")
.value("at_most_one", BoolArgumentConstraint::kAtMostOne)
.value("bool_and", BoolArgumentConstraint::kBoolAnd)
.value("bool_or", BoolArgumentConstraint::kBoolOr)
.value("bool_xor", BoolArgumentConstraint::kBoolXor)
.value("exactly_one", BoolArgumentConstraint::kExactlyOne)
.export_values();
py::enum_<LinearArgumentConstraint>(m, "LinearArgumentConstraint")
.value("div", LinearArgumentConstraint::kDiv)
.value("max", LinearArgumentConstraint::kMax)
.value("min", LinearArgumentConstraint::kMin)
.value("mod", LinearArgumentConstraint::kMod)
.value("prod", LinearArgumentConstraint::kProd)
.export_values();
py::class_<CpBaseModel, std::shared_ptr<CpBaseModel>>(
m, "CpBaseModel", "Base class for the CP model.")
.def(py::init<>())