From 79f2c45c335ae3043f008c8b0cc81a47ad5d11a6 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sat, 1 Jan 2022 19:26:39 +0100 Subject: [PATCH] [CP-SAT] Use AddExactly/AtMostOne in examples/samples; add int_square presolve; add multiplication constraint with target = left * right --- examples/cpp/binpacking_2d_sat.cc | 3 +- examples/python/bus_driver_scheduling_sat.py | 14 ++++----- examples/python/flexible_job_shop_sat.py | 2 +- examples/python/knapsack_2d_sat.py | 2 +- examples/python/rcpsp_sat.py | 2 +- examples/python/steel_mill_slab_sat.py | 4 +-- .../java/com/google/ortools/sat/CpModel.java | 11 +++++++ ortools/sat/BUILD.bazel | 2 ++ ortools/sat/cp_model.cc | 10 +++++++ ortools/sat/cp_model.h | 5 ++++ ortools/sat/cp_model_presolve.cc | 29 +++++++++++-------- ortools/sat/csharp/CpModel.cs | 11 +++++++ ortools/sat/doc/integer_arithmetic.md | 4 +-- ortools/sat/doc/scheduling.md | 4 +-- ortools/sat/integer_expr.cc | 29 ++----------------- ortools/sat/python/cp_model.py | 2 +- ortools/sat/samples/assignment_groups_sat.py | 9 +++--- ortools/sat/samples/assignment_sat.cc | 12 ++++---- ortools/sat/samples/assignment_sat.py | 4 +-- .../sat/samples/assignment_task_sizes_sat.py | 2 +- ortools/sat/samples/assignment_teams_sat.py | 4 +-- ortools/sat/samples/multiple_knapsack_sat.cc | 6 ++-- ortools/sat/samples/multiple_knapsack_sat.py | 2 +- ortools/sat/samples/nurses_sat.cc | 12 ++++---- ortools/sat/samples/nurses_sat.py | 4 +-- .../overlapping_intervals_sample_sat.py | 4 +-- ortools/sat/samples/schedule_requests_sat.cc | 12 ++++---- ortools/sat/samples/schedule_requests_sat.py | 4 +-- .../sat/samples/step_function_sample_sat.cc | 2 +- .../sat/samples/step_function_sample_sat.py | 2 +- ortools/sat/util.cc | 21 ++++++++++++++ ortools/sat/util.h | 4 +++ ortools/util/sorted_interval_list.cc | 27 +++++++++++++++++ ortools/util/sorted_interval_list.h | 5 ++++ 34 files changed, 174 insertions(+), 96 deletions(-) diff --git a/examples/cpp/binpacking_2d_sat.cc b/examples/cpp/binpacking_2d_sat.cc index c47a259b95..0b9b6272f5 100644 --- a/examples/cpp/binpacking_2d_sat.cc +++ b/examples/cpp/binpacking_2d_sat.cc @@ -164,8 +164,7 @@ void LoadAndSolve(const std::string& file_name, int instance) { cp_model.AddImplication(item_to_bin[item][b], bin_is_used[b]); all_items_in_bin.push_back(item_to_bin[item][b]); } - all_items_in_bin.push_back(bin_is_used[b].Not()); - cp_model.AddBoolOr(all_items_in_bin); + cp_model.AddBoolOr(all_items_in_bin).OnlyEnforceIf(bin_is_used[b]); } // Symmetry breaking. diff --git a/examples/python/bus_driver_scheduling_sat.py b/examples/python/bus_driver_scheduling_sat.py index 76766016c7..2aa69d154b 100644 --- a/examples/python/bus_driver_scheduling_sat.py +++ b/examples/python/bus_driver_scheduling_sat.py @@ -1889,18 +1889,18 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): model.Add(working_times[d] >= min_working_time) # Create circuit constraint. - model.Add(sum(outgoing_source_literals) == 1) + model.AddExactlyOne(outgoing_source_literals) for s in range(num_shifts): - model.Add(sum(outgoing_literals[s]) == 1) - model.Add(sum(incoming_literals[s]) == 1) - model.Add(sum(incoming_sink_literals) == 1) + model.AddExactlyOne(outgoing_literals[s]) + model.AddExactlyOne(incoming_literals[s]) + model.AddExactlyOne(incoming_sink_literals) # Each shift is covered. for s in range(num_shifts): - model.Add(sum(performed[d, s] for d in range(num_drivers)) == 1) + model.AddExactlyOne([performed[d, s] for d in range(num_drivers)]) # Globally, each node has one incoming and one outgoing literal - model.Add(sum(shared_incoming_literals[s]) == 1) - model.Add(sum(shared_outgoing_literals[s]) == 1) + model.AddExactlyOne(shared_incoming_literals[s]) + model.AddExactlyOne(shared_outgoing_literals[s]) # Symmetry breaking diff --git a/examples/python/flexible_job_shop_sat.py b/examples/python/flexible_job_shop_sat.py index f7c8df4f8c..185c4201d3 100644 --- a/examples/python/flexible_job_shop_sat.py +++ b/examples/python/flexible_job_shop_sat.py @@ -152,7 +152,7 @@ def flexible_jobshop(): presences[(job_id, task_id, alt_id)] = l_presence # Select exactly one presence variable. - model.Add(sum(l_presences) == 1) + model.AddExactlyOne(l_presences) else: intervals_per_resources[task[0][1]].append(interval) presences[(job_id, task_id, 0)] = model.NewConstant(1) diff --git a/examples/python/knapsack_2d_sat.py b/examples/python/knapsack_2d_sat.py index 724df40836..cfc21dc531 100644 --- a/examples/python/knapsack_2d_sat.py +++ b/examples/python/knapsack_2d_sat.py @@ -312,7 +312,7 @@ def solve_with_rotations(data, max_height, max_width): rotated = model.NewBoolVar(f'rotated_{i}') ### Exactly one state must be chosen. - model.Add(not_selected + no_rotation + rotated == 1) + model.AddExactlyOne([not_selected, no_rotation, rotated]) ### Define height and width according to the state. dim1 = item_widths[i] diff --git a/examples/python/rcpsp_sat.py b/examples/python/rcpsp_sat.py index 6cb4408483..b714adfaac 100644 --- a/examples/python/rcpsp_sat.py +++ b/examples/python/rcpsp_sat.py @@ -139,7 +139,7 @@ def SolveRcpsp(problem, proto_file, params): ] # Exactly one recipe must be performed. - model.Add(cp_model.LinearExpr.Sum(literals) == 1) + model.AddExactlyOne(literals) else: literals = [1] diff --git a/examples/python/steel_mill_slab_sat.py b/examples/python/steel_mill_slab_sat.py index 0d1457e493..3646aba573 100644 --- a/examples/python/steel_mill_slab_sat.py +++ b/examples/python/steel_mill_slab_sat.py @@ -340,7 +340,7 @@ def steel_mill_slab(problem, break_symmetries): # Orders are assigned to one slab. for o in all_orders: - model.Add(sum(assign[o]) == 1) + model.AddExactlyOne(assign[o]) # Redundant constraint (sum of loads == sum of widths). model.Add(sum(loads) == sum(widths)) @@ -523,7 +523,7 @@ def steel_mill_slab_with_valid_slabs(problem, break_symmetries): # Orders are assigned to one slab. for o in all_orders: - model.Add(sum(assign[o]) == 1) + model.AddExactlyOne(assign[o]) # Redundant constraint (sum of loads == sum of widths). model.Add(sum(loads) == sum(widths)) diff --git a/ortools/java/com/google/ortools/sat/CpModel.java b/ortools/java/com/google/ortools/sat/CpModel.java index 855fb97ceb..85f2ac9d91 100644 --- a/ortools/java/com/google/ortools/sat/CpModel.java +++ b/ortools/java/com/google/ortools/sat/CpModel.java @@ -649,6 +649,17 @@ public final class CpModel { return ct; } + /** Adds {@code target == left * right}. */ + public Constraint addMultiplicationEquality( + LinearExpr target, LinearExpr left, LinearExpr right) { + Constraint ct = new Constraint(modelBuilder); + LinearArgumentProto.Builder intProd = ct.getBuilder().getIntProdBuilder(); + intProd.setTarget(getLinearExpressionProtoBuilderFromLinearExpr(target, /*negate=*/false)); + intProd.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(left, /*negate=*/false)); + intProd.addExprs(getLinearExpressionProtoBuilderFromLinearExpr(right, /*negate=*/false)); + return ct; + } + // Scheduling support. /** diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index fb0ea15f39..5b2b18b1f6 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -801,6 +801,7 @@ cc_library( ":precedences", ":sat_base", ":sat_solver", + ":util", "//ortools/base", "//ortools/base:int_type", "//ortools/base:stl_util", @@ -1182,6 +1183,7 @@ cc_library( "//ortools/base", "//ortools/base:stl_util", "//ortools/util:random_engine", + "//ortools/util:saturated_arithmetic", "//ortools/util:time_limit", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/numeric:int128", diff --git a/ortools/sat/cp_model.cc b/ortools/sat/cp_model.cc index 4cdf85c8c9..542b74745c 100644 --- a/ortools/sat/cp_model.cc +++ b/ortools/sat/cp_model.cc @@ -1170,6 +1170,16 @@ Constraint CpModelBuilder::AddMultiplicationEquality( } return Constraint(proto); } +Constraint CpModelBuilder::AddMultiplicationEquality(const LinearExpr& target, + const LinearExpr& left, + const LinearExpr& right) { + ConstraintProto* const proto = cp_model_.add_constraints(); + *proto->mutable_int_prod()->mutable_target() = LinearExprToProto(target); + *proto->mutable_int_prod()->add_exprs() = LinearExprToProto(left); + *proto->mutable_int_prod()->add_exprs() = LinearExprToProto(right); + + return Constraint(proto); +} Constraint CpModelBuilder::AddNoOverlap(absl::Span vars) { ConstraintProto* const proto = cp_model_.add_constraints(); diff --git a/ortools/sat/cp_model.h b/ortools/sat/cp_model.h index 9cc148509f..54f8be57d1 100644 --- a/ortools/sat/cp_model.h +++ b/ortools/sat/cp_model.h @@ -1005,6 +1005,11 @@ class CpModelBuilder { Constraint AddMultiplicationEquality(const LinearExpr& target, std::initializer_list exprs); + /// Adds target == left * right. + Constraint AddMultiplicationEquality(const LinearExpr& target, + const LinearExpr& left, + const LinearExpr& right); + /** * Adds a no-overlap constraint that ensures that all present intervals do * not overlap in time. diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 2463ee6143..c84d7e2677 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -1101,18 +1101,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { LinearExpressionProtosAreEqual(ct->int_prod().exprs(0), ct->int_prod().exprs(1))) { is_square = true; - const Domain domain = context_->DomainSuperSetOf(ct->int_prod().exprs(0)); - if (domain.Size() < 50) { - // Exact computation - std::vector values; - for (const int64_t value : domain.Values()) { - values.push_back(value * value); - } - implied = Domain::FromValues(values); - } else { - implied = domain.ContinuousMultiplicationBy(domain).IntersectionWith( - {0, std::numeric_limits::max()}); - } + implied = + context_->DomainSuperSetOf(ct->int_prod().exprs(0)).SquareSuperset(); } else { for (const LinearExpressionProto& expr : ct->int_prod().exprs()) { implied = @@ -1129,6 +1119,21 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { is_square ? "int_square" : "int_prod", ": reduced target domain.")); } + // y = x * x, we can reduce the domain of x from the domain of y. + if (is_square) { + const int64_t target_max = context_->MaxOf(ct->int_prod().target()); + DCHECK_GE(target_max, 0); + const int64_t sqrt_max = FloorSquareRoot(target_max); + bool expr_reduced = false; + if (!context_->IntersectDomainWith(ct->int_prod().exprs(0), + {-sqrt_max, sqrt_max}, &expr_reduced)) { + return false; + } + if (expr_reduced) { + context_->UpdateRuleStats("int_square: reduced expr domain."); + } + } + if (ct->int_prod().exprs_size() == 2) { LinearExpressionProto a = ct->int_prod().exprs(0); LinearExpressionProto b = ct->int_prod().exprs(1); diff --git a/ortools/sat/csharp/CpModel.cs b/ortools/sat/csharp/CpModel.cs index 8d3ab90055..8060d37237 100644 --- a/ortools/sat/csharp/CpModel.cs +++ b/ortools/sat/csharp/CpModel.cs @@ -611,6 +611,17 @@ public class CpModel return ct; } + public Constraint AddMultiplicationEquality(LinearExpr target, LinearExpr left, LinearExpr right) + { + Constraint ct = new Constraint(model_); + LinearArgumentProto args = new LinearArgumentProto(); + args.Target = GetLinearExpressionProto(target); + args.Exprs.Add(GetLinearExpressionProto(left)); + args.Exprs.Add(GetLinearExpressionProto(right)); + ct.Proto.IntProd = args; + return ct; + } + public Constraint AddProdEquality(IntVar target, IEnumerable vars) { return AddMultiplicationEquality(target, vars); diff --git a/ortools/sat/doc/integer_arithmetic.md b/ortools/sat/doc/integer_arithmetic.md index d733de17ad..f5de7d82fd 100644 --- a/ortools/sat/doc/integer_arithmetic.md +++ b/ortools/sat/doc/integer_arithmetic.md @@ -736,7 +736,7 @@ def step_function_sample_sat(): model.Add(x == 7).OnlyEnforceIf(b3) model.Add(expr == 3).OnlyEnforceIf(b3) - # At least one bi is true. (we could use a sum == 1). + # At least one bi is true. (we could use an exactly one constraint). model.AddBoolOr([b0, b2, b3]) # Search for x values in increasing order. @@ -805,7 +805,7 @@ void StepFunctionSampleSat() { cp_model.AddEquality(x, 7).OnlyEnforceIf(b3); cp_model.AddEquality(expr, 3).OnlyEnforceIf(b3); - // At least one bi is true. (we could use a sum == 1). + // At least one bi is true. (we could use an exactly one constraint). cp_model.AddBoolOr({b0, b2, b3}); // Search for x values in increasing order. diff --git a/ortools/sat/doc/scheduling.md b/ortools/sat/doc/scheduling.md index 619737e157..e62e2dc7a2 100644 --- a/ortools/sat/doc/scheduling.md +++ b/ortools/sat/doc/scheduling.md @@ -1507,8 +1507,8 @@ def OverlappingIntervals(): model.AddImplication(a_after_b, a_overlaps_b.Not()) model.AddImplication(b_after_a, a_overlaps_b.Not()) - # Option b: using a sum() == 1. - # model.Add(a_after_b + b_after_a + a_overlaps_b == 1) + # Option b: using an exactly one constraint. + # model.AddExactlyOne([a_after_b, b_after_a, a_overlaps_b]) # Search for start values in increasing order for the two intervals. model.AddDecisionStrategy([start_var_a, start_var_b], cp_model.CHOOSE_FIRST, diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index a907158761..735f1ae8d0 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -22,6 +22,7 @@ #include "absl/memory/memory.h" #include "ortools/base/stl_util.h" #include "ortools/sat/integer.h" +#include "ortools/sat/util.h" #include "ortools/util/sorted_interval_list.h" #include "ortools/util/time_limit.h" @@ -964,30 +965,6 @@ void ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); } -namespace { - -// TODO(user): Find better implementation? In pratice passing via double is -// almost always correct, but the CapProd() might be a bit slow. However this -// is only called when we do propagate something. -IntegerValue FloorSquareRoot(IntegerValue a) { - IntegerValue result(static_cast( - std::floor(std::sqrt(static_cast(a.value()))))); - while (CapProd(result.value(), result.value()) > a) --result; - while (CapProd(result.value() + 1, result.value() + 1) <= a) ++result; - return result; -} - -// TODO(user): Find better implementation? -IntegerValue CeilSquareRoot(IntegerValue a) { - IntegerValue result(static_cast( - std::ceil(std::sqrt(static_cast(a.value()))))); - while (CapProd(result.value(), result.value()) < a) ++result; - while ((result.value() - 1) * (result.value() - 1) >= a) --result; - return result; -} - -} // namespace - SquarePropagator::SquarePropagator(AffineExpression x, AffineExpression s, IntegerTrail* integer_trail) : x_(x), s_(s), integer_trail_(integer_trail) { @@ -1006,7 +983,7 @@ bool SquarePropagator::Propagate() { return false; } } else if (min_x_square < min_s) { - const IntegerValue new_min = CeilSquareRoot(min_s); + const IntegerValue new_min(CeilSquareRoot(min_s.value())); if (!integer_trail_->SafeEnqueue( x_.GreaterOrEqual(new_min), {s_.GreaterOrEqual((new_min - 1) * (new_min - 1) + 1)})) { @@ -1023,7 +1000,7 @@ bool SquarePropagator::Propagate() { return false; } } else if (max_x_square > max_s) { - const IntegerValue new_max = FloorSquareRoot(max_s); + const IntegerValue new_max(FloorSquareRoot(max_s.value())); if (!integer_trail_->SafeEnqueue( x_.LowerOrEqual(new_max), {s_.LowerOrEqual(IntegerValue(CapProd(new_max.value() + 1, diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 7b0a7bb7de..7b3193ced6 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -1597,7 +1597,7 @@ class CpModel(object): return ct def AddMultiplicationEquality(self, target, expressions): - """Adds `target == variables[0] * .. * variables[n]`.""" + """Adds `target == expressions[0] * .. * expressions[n]`.""" ct = Constraint(self.__model.constraints) model_ct = self.__model.constraints[ct.Index()] model_ct.int_prod.exprs.extend( diff --git a/ortools/sat/samples/assignment_groups_sat.py b/ortools/sat/samples/assignment_groups_sat.py index 399a8ed6d3..1d79cb606d 100755 --- a/ortools/sat/samples/assignment_groups_sat.py +++ b/ortools/sat/samples/assignment_groups_sat.py @@ -83,11 +83,11 @@ def main(): # [START constraints] # Each worker is assigned to at most one task. for worker in range(num_workers): - model.Add(sum(x[worker, task] for task in range(num_tasks)) <= 1) + model.AddAtMostOne([x[worker, task] for task in range(num_tasks)]) # Each task is assigned to exactly one worker. for task in range(num_tasks): - model.Add(sum(x[worker, task] for worker in range(num_workers)) == 1) + model.AddExactlyOne([x[worker, task] for worker in range(num_workers)]) # [END constraints] # [START assignments] @@ -97,8 +97,9 @@ def main(): work[worker] = model.NewBoolVar(f'work[{worker}]') for worker in range(num_workers): - model.Add(work[worker] == sum( - x[worker, task] for task in range(num_tasks))) + for task in range(num_tasks): + model.Add(work[worker] == sum( + x[worker, task] for task in range(num_tasks))) # Define the allowed groups of worders model.AddAllowedAssignments([work[0], work[1], work[2], work[3]], group1) diff --git a/ortools/sat/samples/assignment_sat.cc b/ortools/sat/samples/assignment_sat.cc index 02078ac40d..ff3e085b96 100644 --- a/ortools/sat/samples/assignment_sat.cc +++ b/ortools/sat/samples/assignment_sat.cc @@ -25,8 +25,8 @@ void IntegerProgrammingExample() { {90, 80, 75, 70}, {35, 85, 55, 65}, {125, 95, 90, 95}, {45, 110, 95, 115}, {50, 100, 90, 100}, }; - const int num_workers = costs.size(); - const int num_tasks = costs[0].size(); + const int num_workers = static_cast(costs.size()); + const int num_tasks = static_cast(costs[0].size()); // [END data_model] // Model @@ -51,15 +51,15 @@ void IntegerProgrammingExample() { // [START constraints] // Each worker is assigned to at most one task. for (int i = 0; i < num_workers; ++i) { - cp_model.AddLessOrEqual(LinearExpr::Sum(x[i]), 1); + cp_model.AddAtMostOne(x[i]); } // Each task is assigned to exactly one worker. for (int j = 0; j < num_tasks; ++j) { - LinearExpr task_sum; + std::vector tasks; for (int i = 0; i < num_workers; ++i) { - task_sum.AddTerm(x[i][j], 1); + tasks.push_back(x[i][j]); } - cp_model.AddEquality(task_sum, 1); + cp_model.AddExactlyOne(tasks); } // [END constraints] diff --git a/ortools/sat/samples/assignment_sat.py b/ortools/sat/samples/assignment_sat.py index 0045ab855d..728bf1ad81 100755 --- a/ortools/sat/samples/assignment_sat.py +++ b/ortools/sat/samples/assignment_sat.py @@ -51,11 +51,11 @@ def main(): # [START constraints] # Each worker is assigned to at most one task. for i in range(num_workers): - model.Add(sum(x[i][j] for j in range(num_tasks)) <= 1) + model.AddAtMostOne([x[i][j] for j in range(num_tasks)]) # Each task is assigned to exactly one worker. for j in range(num_tasks): - model.Add(sum(x[i][j] for i in range(num_workers)) == 1) + model.AddExactlyOne([x[i][j] for i in range(num_workers)]) # [END constraints] # Objective diff --git a/ortools/sat/samples/assignment_task_sizes_sat.py b/ortools/sat/samples/assignment_task_sizes_sat.py index ef42d82af1..aea8f44585 100755 --- a/ortools/sat/samples/assignment_task_sizes_sat.py +++ b/ortools/sat/samples/assignment_task_sizes_sat.py @@ -64,7 +64,7 @@ def main(): # Each task is assigned to exactly one worker. for task in range(num_tasks): - model.Add(sum(x[worker, task] for worker in range(num_workers)) == 1) + model.AddExactlyOne([x[worker, task] for worker in range(num_workers)]) # [END constraints] # Objective diff --git a/ortools/sat/samples/assignment_teams_sat.py b/ortools/sat/samples/assignment_teams_sat.py index 3e05e444b5..f93d5ce32c 100755 --- a/ortools/sat/samples/assignment_teams_sat.py +++ b/ortools/sat/samples/assignment_teams_sat.py @@ -55,11 +55,11 @@ def main(): # [START constraints] # Each worker is assigned to at most one task. for worker in range(num_workers): - model.Add(sum(x[worker, task] for task in range(num_tasks)) <= 1) + model.AddAtMostOne([x[worker, task] for task in range(num_tasks)]) # Each task is assigned to exactly one worker. for task in range(num_tasks): - model.Add(sum(x[worker, task] for worker in range(num_workers)) == 1) + model.AddExactlyOne([x[worker, task] for worker in range(num_workers)]) # Each team takes at most two tasks. team1_tasks = [] diff --git a/ortools/sat/samples/multiple_knapsack_sat.cc b/ortools/sat/samples/multiple_knapsack_sat.cc index eb5fe08ffa..d0f5523826 100644 --- a/ortools/sat/samples/multiple_knapsack_sat.cc +++ b/ortools/sat/samples/multiple_knapsack_sat.cc @@ -62,11 +62,11 @@ void MultipleKnapsackSat() { // [START constraints] // Each item is assigned to at most one bin. for (int i : all_items) { - LinearExpr expr; + std::vector copies; for (int b : all_bins) { - expr += x[std::make_tuple(i, b)]; + copies.push_back(x[std::make_tuple(i, b)]); } - cp_model.AddLessOrEqual(expr, 1); + cp_model.AddAtMostOne(copies); } // The amount packed in each bin cannot exceed its capacity. diff --git a/ortools/sat/samples/multiple_knapsack_sat.py b/ortools/sat/samples/multiple_knapsack_sat.py index 36166203b3..7212a177c6 100755 --- a/ortools/sat/samples/multiple_knapsack_sat.py +++ b/ortools/sat/samples/multiple_knapsack_sat.py @@ -53,7 +53,7 @@ def main(): # [START constraints] # Each item is assigned to at most one bin. for i in data['all_items']: - model.Add(sum(x[i, b] for b in data['all_bins']) <= 1) + model.AddAtMostOne([x[i, b] for b in data['all_bins']]) # The amount packed in each bin cannot exceed its capacity. for b in data['all_bins']: diff --git a/ortools/sat/samples/nurses_sat.cc b/ortools/sat/samples/nurses_sat.cc index a63277af95..61780051ce 100644 --- a/ortools/sat/samples/nurses_sat.cc +++ b/ortools/sat/samples/nurses_sat.cc @@ -70,12 +70,12 @@ void NurseSat() { // [START exactly_one_nurse] for (int d : all_days) { for (int s : all_shifts) { - LinearExpr sum; + std::vector nurses; for (int n : all_nurses) { auto key = std::make_tuple(n, d, s); - sum += shifts[key]; + nurses.push_back(shifts[key]); } - cp_model.AddEquality(sum, 1); + cp_model.AddExactlyOne(nurses); } } // [END exactly_one_nurse] @@ -84,12 +84,12 @@ void NurseSat() { // [START at_most_one_shift] for (int n : all_nurses) { for (int d : all_days) { - LinearExpr sum; + std::vector work; for (int s : all_shifts) { auto key = std::make_tuple(n, d, s); - sum += shifts[key]; + work.push_back(shifts[key]); } - cp_model.AddLessOrEqual(sum, 1); + cp_model.AddAtMostOne(work); } } // [END at_most_one_shift] diff --git a/ortools/sat/samples/nurses_sat.py b/ortools/sat/samples/nurses_sat.py index 8aed06e40a..af572f4ace 100755 --- a/ortools/sat/samples/nurses_sat.py +++ b/ortools/sat/samples/nurses_sat.py @@ -49,14 +49,14 @@ def main(): # [START exactly_one_nurse] for d in all_days: for s in all_shifts: - model.Add(sum(shifts[(n, d, s)] for n in all_nurses) == 1) + model.AddExactlyOne([shifts[(n, d, s)] for n in all_nurses]) # [END exactly_one_nurse] # Each nurse works at most one shift per day. # [START at_most_one_shift] for n in all_nurses: for d in all_days: - model.Add(sum(shifts[(n, d, s)] for s in all_shifts) <= 1) + model.AddAtMostOne([shifts[(n, d, s)] for s in all_shifts]) # [END at_most_one_shift] # [START assign_nurses_evenly] diff --git a/ortools/sat/samples/overlapping_intervals_sample_sat.py b/ortools/sat/samples/overlapping_intervals_sample_sat.py index fc6e5915e5..d6ab672545 100755 --- a/ortools/sat/samples/overlapping_intervals_sample_sat.py +++ b/ortools/sat/samples/overlapping_intervals_sample_sat.py @@ -72,8 +72,8 @@ def OverlappingIntervals(): model.AddImplication(a_after_b, a_overlaps_b.Not()) model.AddImplication(b_after_a, a_overlaps_b.Not()) - # Option b: using a sum() == 1. - # model.Add(a_after_b + b_after_a + a_overlaps_b == 1) + # Option b: using an exactly one constraint. + # model.AddExactlyOne([a_after_b, b_after_a, a_overlaps_b]) # Search for start values in increasing order for the two intervals. model.AddDecisionStrategy([start_var_a, start_var_b], cp_model.CHOOSE_FIRST, diff --git a/ortools/sat/samples/schedule_requests_sat.cc b/ortools/sat/samples/schedule_requests_sat.cc index ecf33496cf..a0f59cbea8 100644 --- a/ortools/sat/samples/schedule_requests_sat.cc +++ b/ortools/sat/samples/schedule_requests_sat.cc @@ -115,12 +115,12 @@ void ScheduleRequestsSat() { // [START exactly_one_nurse] for (int d : all_days) { for (int s : all_shifts) { - LinearExpr sum; + std::vector nurses; for (int n : all_nurses) { auto key = std::make_tuple(n, d, s); - sum += shifts[key]; + nurses.push_back(shifts[key]); } - cp_model.AddEquality(sum, 1); + cp_model.AddExactlyOne(nurses); } } // [END exactly_one_nurse] @@ -129,12 +129,12 @@ void ScheduleRequestsSat() { // [START at_most_one_shift] for (int n : all_nurses) { for (int d : all_days) { - LinearExpr sum; + std::vector work; for (int s : all_shifts) { auto key = std::make_tuple(n, d, s); - sum += shifts[key]; + work.push_back(shifts[key]); } - cp_model.AddLessOrEqual(sum, 1); + cp_model.AddAtMostOne(work); } } // [END at_most_one_shift] diff --git a/ortools/sat/samples/schedule_requests_sat.py b/ortools/sat/samples/schedule_requests_sat.py index ba9c1a2f5f..fbffee2f5b 100755 --- a/ortools/sat/samples/schedule_requests_sat.py +++ b/ortools/sat/samples/schedule_requests_sat.py @@ -62,14 +62,14 @@ def main(): # [START exactly_one_nurse] for d in all_days: for s in all_shifts: - model.Add(sum(shifts[(n, d, s)] for n in all_nurses) == 1) + model.AddExactlyOne([shifts[(n, d, s)] for n in all_nurses]) # [END exactly_one_nurse] # Each nurse works at most one shift per day. # [START at_most_one_shift] for n in all_nurses: for d in all_days: - model.Add(sum(shifts[(n, d, s)] for s in all_shifts) <= 1) + model.AddAtMostOne([shifts[(n, d, s)] for s in all_shifts]) # [END at_most_one_shift] # [START assign_nurses_evenly] diff --git a/ortools/sat/samples/step_function_sample_sat.cc b/ortools/sat/samples/step_function_sample_sat.cc index 1190fd6e7b..0cd4c90223 100644 --- a/ortools/sat/samples/step_function_sample_sat.cc +++ b/ortools/sat/samples/step_function_sample_sat.cc @@ -54,7 +54,7 @@ void StepFunctionSampleSat() { cp_model.AddEquality(x, 7).OnlyEnforceIf(b3); cp_model.AddEquality(expr, 3).OnlyEnforceIf(b3); - // At least one bi is true. (we could use a sum == 1). + // At least one bi is true. (we could use an exactly one constraint). cp_model.AddBoolOr({b0, b2, b3}); // Search for x values in increasing order. diff --git a/ortools/sat/samples/step_function_sample_sat.py b/ortools/sat/samples/step_function_sample_sat.py index da7c38e9f4..5eb76c261d 100755 --- a/ortools/sat/samples/step_function_sample_sat.py +++ b/ortools/sat/samples/step_function_sample_sat.py @@ -72,7 +72,7 @@ def step_function_sample_sat(): model.Add(x == 7).OnlyEnforceIf(b3) model.Add(expr == 3).OnlyEnforceIf(b3) - # At least one bi is true. (we could use a sum == 1). + # At least one bi is true. (we could use an exactly one constraint). model.AddBoolOr([b0, b2, b3]) # Search for x values in increasing order. diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index b1235d3821..28f7bbdc04 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -19,6 +19,7 @@ #include "absl/numeric/int128.h" #include "ortools/base/stl_util.h" +#include "ortools/util/saturated_arithmetic.h" namespace operations_research { namespace sat { @@ -166,6 +167,26 @@ bool SolveDiophantineEquationOfSizeTwo(int64_t& a, int64_t& b, int64_t& cte, return true; } +// TODO(user): Find better implementation? In pratice passing via double is +// almost always correct, but the CapProd() might be a bit slow. However this +// is only called when we do propagate something. +int64_t FloorSquareRoot(int64_t a) { + int64_t result = + static_cast(std::floor(std::sqrt(static_cast(a)))); + while (CapProd(result, result) > a) --result; + while (CapProd(result + 1, result + 1) <= a) ++result; + return result; +} + +// TODO(user): Find better implementation? +int64_t CeilSquareRoot(int64_t a) { + int64_t result = + static_cast(std::ceil(std::sqrt(static_cast(a)))); + while (CapProd(result, result) < a) ++result; + while ((result - 1) * (result - 1) >= a) --result; + return result; +} + int MoveOneUnprocessedLiteralLast(const std::set& processed, int relevant_prefix_size, std::vector* literals) { diff --git a/ortools/sat/util.h b/ortools/sat/util.h index f5596de593..2ae5635f46 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -77,6 +77,10 @@ int64_t ProductWithModularInverse(int64_t coeff, int64_t mod, int64_t rhs); bool SolveDiophantineEquationOfSizeTwo(int64_t& a, int64_t& b, int64_t& cte, int64_t& x0, int64_t& y0); +// The argument must be non-negative. +int64_t FloorSquareRoot(int64_t a); +int64_t CeilSquareRoot(int64_t a); + // The model "singleton" random engine used in the solver. // // In test, we usually set use_absl_random() so that the sequence is changed at diff --git a/ortools/util/sorted_interval_list.cc b/ortools/util/sorted_interval_list.cc index 3a9393391c..453b7584f5 100644 --- a/ortools/util/sorted_interval_list.cc +++ b/ortools/util/sorted_interval_list.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -519,6 +520,32 @@ Domain Domain::PositiveDivisionBySuperset(const Domain& divisor) const { std::max(Max() / divisor.Min(), Max() / divisor.Max())); } +Domain Domain::SquareSuperset() const { + if (IsEmpty()) return Domain(); + const Domain abs_domain = + IntersectionWith({0, std::numeric_limits::max()}) + .UnionWith(Negation().IntersectionWith( + {0, std::numeric_limits::max()})); + if (abs_domain.Size() >= kDomainComplexityLimit) { + Domain result; + result.intervals_.reserve(abs_domain.NumIntervals()); + for (const auto& interval : abs_domain.intervals()) { + result.intervals_.push_back( + ClosedInterval(CapProd(interval.start, interval.start), + CapProd(interval.end, interval.end))); + } + UnionOfSortedIntervals(&result.intervals_); + return result; + } else { + std::vector values; + values.reserve(abs_domain.Size()); + for (const int64_t value : abs_domain.Values()) { + values.push_back(CapProd(value, value)); + } + return Domain::FromValues(values); + } +} + // It is a bit difficult to see, but this code is doing the same thing as // for all interval in this.UnionWith(implied_domain.Complement())): // - Take the two extreme points (min and max) in interval \inter implied. diff --git a/ortools/util/sorted_interval_list.h b/ortools/util/sorted_interval_list.h index 77215c283b..d3b37ddf6e 100644 --- a/ortools/util/sorted_interval_list.h +++ b/ortools/util/sorted_interval_list.h @@ -372,6 +372,11 @@ class Domain { */ Domain PositiveDivisionBySuperset(const Domain& divisor) const; + /** + * Returns a superset of {x ∈ Int64, ∃ y ∈ D, x = y * y }. + */ + Domain SquareSuperset() const; + /** * Advanced usage. Given some \e implied information on this domain that is * assumed to be always true (i.e. only values in the intersection with