From 374cc3e596ad6e10c8be51aa264c40778b07651f Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sun, 27 Jul 2025 15:29:20 -0700 Subject: [PATCH] polish code --- ortools/sat/cp_constraints.cc | 2 +- ortools/sat/cp_constraints.h | 2 +- ortools/sat/python/cp_model.py | 50 ++++++++---- ortools/sat/python/cp_model_helper.cc | 106 ++++++++++++++++++-------- 4 files changed, 114 insertions(+), 46 deletions(-) diff --git a/ortools/sat/cp_constraints.cc b/ortools/sat/cp_constraints.cc index 3dfa1dc8d8..341ca496bd 100644 --- a/ortools/sat/cp_constraints.cc +++ b/ortools/sat/cp_constraints.cc @@ -374,7 +374,7 @@ EnforcementStatus EnforcementPropagator::DebugStatus(EnforcementId id) { } BooleanXorPropagator::BooleanXorPropagator( - const std::vector& enforcement_literals, + absl::Span enforcement_literals, const std::vector& literals, bool value, Model* model) : literals_(literals), value_(value), diff --git a/ortools/sat/cp_constraints.h b/ortools/sat/cp_constraints.h index 49d6394dfa..4e8a296366 100644 --- a/ortools/sat/cp_constraints.h +++ b/ortools/sat/cp_constraints.h @@ -159,7 +159,7 @@ class EnforcementPropagator : public SatPropagator { // faster. class BooleanXorPropagator : public PropagatorInterface { public: - BooleanXorPropagator(const std::vector& enforcement_literals, + BooleanXorPropagator(absl::Span enforcement_literals, const std::vector& literals, bool value, Model* model); diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 12bc7bc866..7bc87255bd 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -184,7 +184,7 @@ _IndexOrSeries = Union[pd.Index, pd.Series] # Helper functions. def snake_case_to_camel_case(name: str) -> str: - """Converts a snake_case name to camelCase.""" + """Converts a snake_case name to CamelCase.""" words = name.split("_") return ( "".join(word.capitalize() for word in words) @@ -900,7 +900,9 @@ class CpModel(cmh.CpBaseModel): def add_bool_or(self, *literals): """Adds `Or(literals) == true`: sum(literals) >= 1.""" - return self._add_bool_argument_constraint("or", *literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_or, *literals + ) @overload def add_at_least_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -910,7 +912,9 @@ class CpModel(cmh.CpBaseModel): def add_at_least_one(self, *literals): """Same as `add_bool_or`: `sum(literals) >= 1`.""" - return self._add_bool_argument_constraint("or", *literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_or, *literals + ) @overload def add_at_most_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -920,7 +924,9 @@ class CpModel(cmh.CpBaseModel): def add_at_most_one(self, *literals) -> Constraint: """Adds `AtMostOne(literals)`: `sum(literals) <= 1`.""" - return self._add_bool_argument_constraint("at_most_one", *literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.at_most_one, *literals + ) @overload def add_exactly_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -930,7 +936,9 @@ class CpModel(cmh.CpBaseModel): def add_exactly_one(self, *literals): """Adds `ExactlyOne(literals)`: `sum(literals) == 1`.""" - return self._add_bool_argument_constraint("exactly_one", *literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.exactly_one, *literals + ) @overload def add_bool_and(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -940,7 +948,9 @@ class CpModel(cmh.CpBaseModel): def add_bool_and(self, *literals): """Adds `And(literals) == true`.""" - return self._add_bool_argument_constraint("and", *literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_and, *literals + ) @overload def add_bool_xor(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -960,7 +970,9 @@ class CpModel(cmh.CpBaseModel): Returns: An `Constraint` object. """ - return self._add_bool_argument_constraint("xor", *literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_xor, *literals + ) @overload def add_min_equality( @@ -974,7 +986,9 @@ class CpModel(cmh.CpBaseModel): def add_min_equality(self, target, *expressions) -> Constraint: """Adds `target == Min(expressions)`.""" - return self._add_linear_argument_constraint("min", target, *expressions) + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.min, target, *expressions + ) @overload def add_max_equality( @@ -988,17 +1002,23 @@ class CpModel(cmh.CpBaseModel): def add_max_equality(self, target, *expressions) -> Constraint: """Adds `target == Max(expressions)`.""" - return self._add_linear_argument_constraint("max", target, *expressions) + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.max, target, *expressions + ) def add_division_equality( self, target: LinearExprT, num: LinearExprT, denom: LinearExprT ) -> Constraint: """Adds `target == num // denom` (integer division rounded towards 0).""" - return self._add_linear_argument_constraint("div", target, [num, denom]) + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.div, target, [num, denom] + ) def add_abs_equality(self, target: LinearExprT, expr: LinearExprT) -> Constraint: """Adds `target == Abs(expr)`.""" - return self._add_linear_argument_constraint("max", target, [expr, -expr]) + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.max, target, [expr, -expr] + ) def add_modulo_equality( self, target: LinearExprT, expr: LinearExprT, mod: LinearExprT @@ -1022,7 +1042,9 @@ class CpModel(cmh.CpBaseModel): Returns: A `Constraint` object. """ - return self._add_linear_argument_constraint("mod", target, [expr, mod]) + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.mod, target, [expr, mod] + ) def add_multiplication_equality( self, @@ -1030,7 +1052,9 @@ class CpModel(cmh.CpBaseModel): *expressions: Union[Iterable[LinearExprT], LinearExprT], ) -> Constraint: """Adds `target == expressions[0] * .. * expressions[n]`.""" - return self._add_linear_argument_constraint("prod", target, *expressions) + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.prod, target, *expressions + ) # Scheduling support diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index aebea3da72..7bd52f513f 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -462,6 +462,22 @@ void LinearExprToProto(const py::handle& arg, int64_t multiplier, class Constraint; class IntervalVar; +enum class BoolArgumentConstraint { + kAtMostOne, + kBoolAnd, + kBoolOr, + kBoolXor, + kExactlyOne, +}; + +enum class LinearArgumentConstraint { + kDiv, + kMax, + kMin, + kMod, + kProd, +}; + class CpBaseModel : public std::enable_shared_from_this { public: CpBaseModel() @@ -573,7 +589,7 @@ class CpBaseModel : public std::enable_shared_from_this { const std::vector>& transition_triples); std::shared_ptr AddBoolArgumentConstraintInternal( - const std::string& name, py::args literals); + BoolArgumentConstraint type, py::args literals); std::shared_ptr AddBoundedLinearExpressionInternal( BoundedLinearExpression* ble); @@ -586,7 +602,7 @@ class CpBaseModel : public std::enable_shared_from_this { py::sequence inverse); std::shared_ptr AddLinearArgumentConstraintInternal( - const std::string& name, const py::handle& target, py::args exprs); + LinearArgumentConstraint type, const py::handle& target, py::args exprs); std::shared_ptr AddReservoirInternal(py::sequence times, py::sequence level_changes, @@ -702,23 +718,29 @@ std::shared_ptr CpBaseModel::AddAutomatonInternal( } std::shared_ptr CpBaseModel::AddBoolArgumentConstraintInternal( - const std::string& name, py::args literals) { + BoolArgumentConstraint type, py::args literals) { const int ct_index = model_proto_->constraints_size(); ConstraintProto* ct = model_proto_->add_constraints(); BoolArgumentProto* proto = nullptr; - if (name == "or") { - proto = ct->mutable_bool_or(); - } else if (name == "and") { - proto = ct->mutable_bool_and(); - } else if (name == "xor") { - proto = ct->mutable_bool_xor(); - } else if (name == "at_most_one") { - proto = ct->mutable_at_most_one(); - } else if (name == "exactly_one") { - proto = ct->mutable_exactly_one(); - } else { - ThrowError(PyExc_ValueError, - absl::StrCat("Unknown boolean argument constraint: ", name)); + switch (type) { + case BoolArgumentConstraint::kAtMostOne: + proto = ct->mutable_at_most_one(); + break; + case BoolArgumentConstraint::kBoolAnd: + proto = ct->mutable_bool_and(); + break; + case BoolArgumentConstraint::kBoolOr: + proto = ct->mutable_bool_or(); + break; + case BoolArgumentConstraint::kBoolXor: + proto = ct->mutable_bool_xor(); + break; + case BoolArgumentConstraint::kExactlyOne: + proto = ct->mutable_exactly_one(); + break; + default: + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown boolean argument constraint: ", type)); } if (literals.size() == 1 && py::isinstance(literals[0])) { for (const auto& literal : literals[0]) { @@ -781,25 +803,31 @@ std::shared_ptr CpBaseModel::AddInverseInternal( } std::shared_ptr CpBaseModel::AddLinearArgumentConstraintInternal( - const std::string& name, const py::handle& target, py::args exprs) { + LinearArgumentConstraint type, const py::handle& target, py::args exprs) { const int ct_index = model_proto_->constraints_size(); ConstraintProto* ct = model_proto_->add_constraints(); LinearArgumentProto* proto; int64_t multiplier = 1; - if (name == "min") { - proto = ct->mutable_lin_max(); - multiplier = -1; - } else if (name == "max") { - proto = ct->mutable_lin_max(); - } else if (name == "prod") { - proto = ct->mutable_int_prod(); - } else if (name == "div") { - proto = ct->mutable_int_div(); - } else if (name == "mod") { - proto = ct->mutable_int_mod(); - } else { - ThrowError(PyExc_ValueError, - absl::StrCat("Unknown integer argument constraint: ", name)); + switch (type) { + case LinearArgumentConstraint::kDiv: + proto = ct->mutable_int_div(); + break; + case LinearArgumentConstraint::kMax: + proto = ct->mutable_lin_max(); + break; + case LinearArgumentConstraint::kMin: + proto = ct->mutable_lin_max(); + multiplier = -1; + break; + case LinearArgumentConstraint::kMod: + proto = ct->mutable_int_mod(); + break; + case LinearArgumentConstraint::kProd: + proto = ct->mutable_int_prod(); + break; + default: + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown integer argument constraint: ", type)); } LinearExprToProto(target, multiplier, proto->mutable_target()); @@ -1871,6 +1899,22 @@ PYBIND11_MODULE(cp_model_helper, m) { return false; }); + py::enum_(m, "BoolArgumentConstraint") + .value("at_most_one", BoolArgumentConstraint::kAtMostOne) + .value("bool_and", BoolArgumentConstraint::kBoolAnd) + .value("bool_or", BoolArgumentConstraint::kBoolOr) + .value("bool_xor", BoolArgumentConstraint::kBoolXor) + .value("exactly_one", BoolArgumentConstraint::kExactlyOne) + .export_values(); + + py::enum_(m, "LinearArgumentConstraint") + .value("div", LinearArgumentConstraint::kDiv) + .value("max", LinearArgumentConstraint::kMax) + .value("min", LinearArgumentConstraint::kMin) + .value("mod", LinearArgumentConstraint::kMod) + .value("prod", LinearArgumentConstraint::kProd) + .export_values(); + py::class_>( m, "CpBaseModel", "Base class for the CP model.") .def(py::init<>())