From 1a8a94dace374febf96ca590dd608bd927a94fa4 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Wed, 2 Feb 2022 14:57:42 +0100 Subject: [PATCH] [CP-SAT] add more incremental methods on java constraints; use logger in SatSolver; --- examples/tests/CpModelTest.java | 8 ++ examples/tests/CpSolverTest.java | 91 +++++++++++++++++++ .../ortools/sat/CumulativeConstraint.java | 57 ++++++++++++ .../com/google/ortools/sat/LinearExpr.java | 5 + ortools/sat/BUILD.bazel | 1 + ortools/sat/sat_solver.cc | 50 +++++----- ortools/sat/sat_solver.h | 2 + 7 files changed, 187 insertions(+), 27 deletions(-) diff --git a/examples/tests/CpModelTest.java b/examples/tests/CpModelTest.java index bebb6db0d8..52fd323f0c 100644 --- a/examples/tests/CpModelTest.java +++ b/examples/tests/CpModelTest.java @@ -429,6 +429,14 @@ public final class CpModelTest { assertThat(model.model().getConstraints(1).hasInterval()).isTrue(); assertThat(model.model().getConstraints(2).hasCumulative()).isTrue(); assertThat(model.model().getConstraints(2).getCumulative().getIntervalsCount()).isEqualTo(2); + + cumul.addDemands(new IntervalVar[] {interval1}, new int[] {2}); + cumul.addDemands(new IntervalVar[] {interval1}, new long[] {2}); + cumul.addDemands( + new IntervalVar[] {interval2}, new LinearArgument[] {LinearExpr.affine(demandVar2, 1, 2)}); + cumul.addDemands( + new IntervalVar[] {interval2}, new LinearExpr[] {LinearExpr.affine(demandVar2, 1, 3)}); + assertThat(model.model().getConstraints(2).getCumulative().getIntervalsCount()).isEqualTo(6); } @Test diff --git a/examples/tests/CpSolverTest.java b/examples/tests/CpSolverTest.java index fcf03d6e9a..4bb1cebea1 100644 --- a/examples/tests/CpSolverTest.java +++ b/examples/tests/CpSolverTest.java @@ -19,6 +19,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import com.google.ortools.Loader; import com.google.ortools.sat.CpSolverStatus; +import com.google.ortools.util.Domain; import java.util.function.Consumer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -277,4 +278,94 @@ public final class CpSolverTest { assertThat(log).contains("log_to_stdout: false"); assertThat(log).contains("OPTIMAL"); } + + @Test + public void issue3108() { + final CpModel model = new CpModel(); + final IntVar var1 = model.newIntVar(0, 1, "CONTROLLABLE__C1[0]"); + final IntVar var2 = model.newIntVar(0, 1, "CONTROLLABLE__C1[1]"); + capacityConstraint(model, new IntVar[] {var1, var2}, new long[] {0L, 1L}, + new long[][] {new long[] {1L, 1L}}, new long[][] {new long[] {1L, 1L}}); + final CpSolver solver = new CpSolver(); + solver.getParameters().setLogSearchProgress(false); + solver.getParameters().setCpModelProbingLevel(0); + solver.getParameters().setNumSearchWorkers(4); + solver.getParameters().setMaxTimeInSeconds(1); + final CpSolverStatus status = solver.solve(model); + assertEquals(status, CpSolverStatus.OPTIMAL); + } + + private static void capacityConstraint(final CpModel model, final IntVar[] varsToAssign, + final long[] domainArr, final long[][] demands, final long[][] capacities) { + final int numTasks = varsToAssign.length; + final int numResources = demands.length; + final IntervalVar[] tasksIntervals = new IntervalVar[numTasks + capacities[0].length]; + + final Domain domainT = Domain.fromValues(domainArr); + final Domain intervalRange = + Domain.fromFlatIntervals(new long[] {domainT.min() + 1, domainT.max() + 1}); + final int unitIntervalSize = 1; + for (int i = 0; i < numTasks; i++) { + final BoolVar presence = model.newBoolVar(""); + model.addLinearExpressionInDomain(varsToAssign[i], domainT).onlyEnforceIf(presence); + model.addLinearExpressionInDomain(varsToAssign[i], domainT.complement()) + .onlyEnforceIf(presence.not()); + // interval with start as taskToNodeAssignment and size of 1 + tasksIntervals[i] = + model.newOptionalFixedSizeIntervalVar(varsToAssign[i], unitIntervalSize, presence, ""); + } + + // Create dummy intervals + for (int i = numTasks; i < tasksIntervals.length; i++) { + final int nodeIndex = i - numTasks; + tasksIntervals[i] = model.newFixedInterval(domainArr[nodeIndex], 1, ""); + } + + // Convert to list of arrays + final long[][] nodeCapacities = new long[numResources][]; + final long[] maxCapacities = new long[numResources]; + + for (int i = 0; i < capacities.length; i++) { + final long[] capacityArr = capacities[i]; + long maxCapacityValue = Long.MIN_VALUE; + for (int j = 0; j < capacityArr.length; j++) { + maxCapacityValue = Math.max(maxCapacityValue, capacityArr[j]); + } + nodeCapacities[i] = capacityArr; + maxCapacities[i] = maxCapacityValue; + } + + // For each resource, create dummy demands to accommodate heterogeneous capacities + final long[][] updatedDemands = new long[numResources][]; + for (int i = 0; i < numResources; i++) { + final long[] demand = new long[numTasks + capacities[0].length]; + + // copy ver task demands + int iter = 0; + for (final long taskDemand : demands[i]) { + demand[iter] = taskDemand; + iter++; + } + + // copy over dummy demands + final long maxCapacity = maxCapacities[i]; + for (final long nodeHeterogeneityAdjustment : nodeCapacities[i]) { + demand[iter] = maxCapacity - nodeHeterogeneityAdjustment; + iter++; + } + updatedDemands[i] = demand; + } + + // 2. Capacity constraints + for (int i = 0; i < numResources; i++) { + model.addCumulative(maxCapacities[i]).addDemands(tasksIntervals, updatedDemands[i]); + } + + // Cumulative score + for (int i = 0; i < numResources; i++) { + final IntVar max = model.newIntVar(0, maxCapacities[i], ""); + model.addCumulative(max).addDemands(tasksIntervals, updatedDemands[i]).getBuilder(); + model.minimize(max); + } + } } diff --git a/ortools/java/com/google/ortools/sat/CumulativeConstraint.java b/ortools/java/com/google/ortools/sat/CumulativeConstraint.java index e76a3b3c67..9ab386db9d 100644 --- a/ortools/java/com/google/ortools/sat/CumulativeConstraint.java +++ b/ortools/java/com/google/ortools/sat/CumulativeConstraint.java @@ -43,5 +43,62 @@ public class CumulativeConstraint extends Constraint { return this; } + /** + * Adds all pairs (intervals[i], demands[i]) to the constraint. + * + * @param intervals an array of interval variables + * @param deamds an array of linear expression + * @return itself + * @throws CpModel.MismatchedArrayLengths if intervals and demands have different length + */ + public CumulativeConstraint addDemands(IntervalVar[] intervals, LinearArgument[] demands) { + if (intervals.length != demands.length) { + throw new CpModel.MismatchedArrayLengths( + "CumulativeConstraint.addDemands", "intervals", "demands"); + } + for (int i = 0; i < intervals.length; i++) { + addDemand(intervals[i], demands[i]); + } + return this; + } + + /** + * Adds all pairs (intervals[i], demands[i]) to the constraint. + * + * @param intervals an array of interval variables + * @param deamds an array of long values + * @return itself + * @throws CpModel.MismatchedArrayLengths if intervals and demands have different length + */ + public CumulativeConstraint addDemands(IntervalVar[] intervals, long[] demands) { + if (intervals.length != demands.length) { + throw new CpModel.MismatchedArrayLengths( + "CumulativeConstraint.addDemands", "intervals", "demands"); + } + for (int i = 0; i < intervals.length; i++) { + addDemand(intervals[i], demands[i]); + } + return this; + } + + /** + * Adds all pairs (intervals[i], demands[i]) to the constraint. + * + * @param intervals an array of interval variables + * @param deamds an array of integer values + * @return itself + * @throws CpModel.MismatchedArrayLengths if intervals and demands have different length + */ + public CumulativeConstraint addDemands(IntervalVar[] intervals, int[] demands) { + if (intervals.length != demands.length) { + throw new CpModel.MismatchedArrayLengths( + "CumulativeConstraint.addDemands", "intervals", "demands"); + } + for (int i = 0; i < intervals.length; i++) { + addDemand(intervals[i], demands[i]); + } + return this; + } + private final CpModel model; } diff --git a/ortools/java/com/google/ortools/sat/LinearExpr.java b/ortools/java/com/google/ortools/sat/LinearExpr.java index be6e61ea2f..8372faae2a 100644 --- a/ortools/java/com/google/ortools/sat/LinearExpr.java +++ b/ortools/java/com/google/ortools/sat/LinearExpr.java @@ -44,6 +44,11 @@ public interface LinearExpr extends LinearArgument { return newBuilder().addTerm(expr, coeff).build(); } + /** Shortcut for newBuilder().addTerm(expr, coeff).add(offset).build() */ + static LinearExpr affine(LinearArgument expr, long coeff, long offset) { + return newBuilder().addTerm(expr, coeff).add(offset).build(); + } + /** Shortcut for newBuilder().addSum(exprs).build() */ static LinearExpr sum(LinearArgument[] exprs) { return newBuilder().addSum(exprs).build(); diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 9f0b6b699c..c55fab5692 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -458,6 +458,7 @@ cc_library( "//ortools/base:strong_vector", "//ortools/port:proto_utils", "//ortools/port:sysinfo", + "//ortools/util:logging", "//ortools/util:saturated_arithmetic", "//ortools/util:stats", "//ortools/util:time_limit", diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 1c2d041da7..4c032b420b 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -38,6 +38,7 @@ namespace sat { SatSolver::SatSolver() : SatSolver(new Model()) { owned_model_.reset(model_); model_->Register(this); + logger_ = model_->GetOrCreate(); } SatSolver::SatSolver(Model* model) @@ -51,6 +52,7 @@ SatSolver::SatSolver(Model* model) parameters_(model->GetOrCreate()), restart_(model->GetOrCreate()), decision_policy_(model->GetOrCreate()), + logger_(model->GetOrCreate()), clause_activity_increment_(1.0), same_reason_identifier_(*trail_), is_relevant_for_core_computation_(true), @@ -118,6 +120,8 @@ void SatSolver::SetParameters(const SatParameters& parameters) { *parameters_ = parameters; restart_->Reset(); time_limit_->ResetLimitFromParameters(parameters); + logger_->EnableLogging(parameters.log_search_progress() || VLOG_IS_ON(1)); + logger_->SetLogToStdOut(parameters.log_to_stdout()); } bool SatSolver::IsMemoryLimitReached() const { @@ -956,10 +960,8 @@ SatSolver::Status SatSolver::ResetAndSolveWithGivenAssumptions( } SatSolver::Status SatSolver::StatusWithLog(Status status) { - if (parameters_->log_search_progress()) { - LOG(INFO) << RunningStatisticsString(); - LOG(INFO) << StatusString(status); - } + SOLVER_LOG(logger_, RunningStatisticsString()); + SOLVER_LOG(logger_, StatusString(status)); return status; } @@ -1129,19 +1131,19 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit) { timer_.Restart(); // Display initial statistics. - if (parameters_->log_search_progress()) { - LOG(INFO) << "Initial memory usage: " << MemoryUsage(); - LOG(INFO) << "Number of variables: " << num_variables_; - LOG(INFO) << "Number of clauses (size > 2): " - << clauses_propagator_->num_clauses(); - LOG(INFO) << "Number of binary clauses: " - << binary_implication_graph_->num_implications(); - LOG(INFO) << "Number of linear constraints: " - << pb_constraints_->NumberOfConstraints(); - LOG(INFO) << "Number of fixed variables: " << trail_->Index(); - LOG(INFO) << "Number of watched clauses: " - << clauses_propagator_->num_watched_clauses(); - LOG(INFO) << "Parameters: " << ProtobufShortDebugString(*parameters_); + if (logger_->LoggingIsEnabled()) { + SOLVER_LOG(logger_, "Initial memory usage: ", MemoryUsage()); + SOLVER_LOG(logger_, "Number of variables: ", num_variables_.value()); + SOLVER_LOG(logger_, "Number of clauses (size > 2): ", + clauses_propagator_->num_clauses()); + SOLVER_LOG(logger_, "Number of binary clauses: ", + binary_implication_graph_->num_implications()); + SOLVER_LOG(logger_, "Number of linear constraints: ", + pb_constraints_->NumberOfConstraints()); + SOLVER_LOG(logger_, "Number of fixed variables: ", trail_->Index()); + SOLVER_LOG(logger_, "Number of watched clauses: ", + clauses_propagator_->num_watched_clauses()); + SOLVER_LOG(logger_, "Parameters: ", ProtobufShortDebugString(*parameters_)); } // Used to trigger clause minimization via propagation. @@ -1174,16 +1176,12 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit) { if (time_limit != nullptr) { AdvanceDeterministicTime(time_limit); if (time_limit->LimitReached()) { - if (parameters_->log_search_progress()) { - LOG(INFO) << "The time limit has been reached. Aborting."; - } + SOLVER_LOG(logger_, "The time limit has been reached. Aborting."); return StatusWithLog(LIMIT_REACHED); } } if (num_failures() >= kFailureLimit) { - if (parameters_->log_search_progress()) { - LOG(INFO) << "The conflict limit has been reached. Aborting."; - } + SOLVER_LOG(logger_, "The conflict limit has been reached. Aborting."); return StatusWithLog(LIMIT_REACHED); } @@ -1195,9 +1193,7 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit) { if (counters_.num_failures >= next_memory_check) { next_memory_check = NextMultipleOf(num_failures(), kMemoryCheckFrequency); if (IsMemoryLimitReached()) { - if (parameters_->log_search_progress()) { - LOG(INFO) << "The memory limit has been reached. Aborting."; - } + SOLVER_LOG(logger_, "The memory limit has been reached. Aborting."); return StatusWithLog(LIMIT_REACHED); } } @@ -1205,7 +1201,7 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit) { // Display search progression. We use >= because counters_.num_failures may // augment by more than one at each iteration. if (counters_.num_failures >= next_display) { - LOG(INFO) << RunningStatisticsString(); + SOLVER_LOG(logger_, RunningStatisticsString()); next_display = NextMultipleOf(num_failures(), kDisplayFrequency); } diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 0ac2c460a7..5e6ba890f6 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -43,6 +43,7 @@ #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_decision.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/util/logging.h" #include "ortools/util/stats.h" #include "ortools/util/time_limit.h" @@ -724,6 +725,7 @@ class SatSolver { SatParameters* parameters_; RestartPolicy* restart_; SatDecisionPolicy* decision_policy_; + SolverLogger* logger_; // Used for debugging only. See SaveDebugAssignment(). VariablesAssignment debug_assignment_;