diff --git a/ortools/sat/constraint_violation.cc b/ortools/sat/constraint_violation.cc index 3e24bfe1af..67d8690b4a 100644 --- a/ortools/sat/constraint_violation.cc +++ b/ortools/sat/constraint_violation.cc @@ -16,15 +16,13 @@ #include #include #include -#include #include #include -#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "ortools/base/logging.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" -#include "ortools/util/saturated_arithmetic.h" namespace operations_research { namespace sat { @@ -47,18 +45,18 @@ bool LiteralValue(int lit, absl::Span solution) { int LinearIncrementalEvaluator::NewConstraint(int64_t lb, int64_t ub) { const int ct_index = lower_bounds_.size(); - enforcement_literals_.resize(ct_index + 1); + num_enforcement_literals_.push_back(0); lower_bounds_.push_back(lb); upper_bounds_.push_back(ub); offsets_.push_back(0); - enforced_.push_back(true); + num_true_enforcement_literals_.push_back(0); activities_.push_back(0); return ct_index; } void LinearIncrementalEvaluator::AddEnforcementLiteral(int ct_index, int lit) { // Update the row-major storage. - enforcement_literals_[ct_index].push_back(lit); + num_enforcement_literals_[ct_index]++; // Update the column-major storage. const int var = PositiveRef(lit); @@ -92,23 +90,16 @@ void LinearIncrementalEvaluator::AddOffset(int ct_index, int64_t offset) { void LinearIncrementalEvaluator::ComputeInitialActivities( absl::Span solution) { - // Process enforcement literals. - for (int ct_index = 0; ct_index < enforced_.size(); ++ct_index) { - // Resets the activity as the offset. - activities_[ct_index] = offsets_[ct_index]; + const int num_vars = var_entries_.size(); + const int num_constraints = num_enforcement_literals_.size(); - // Checks if the constraint is enforced. - enforced_[ct_index] = true; - for (const int lit : enforcement_literals_[ct_index]) { - if (!LiteralValue(lit, solution)) { - enforced_[ct_index] = false; - return; - } - } - } + // Resets the activity as the offset. + activities_ = offsets_; // Updates activities from variables and coefficients. - for (int var = 0; var < var_entries_.size(); ++var) { + for (int var = 0; var < num_vars; ++var) { + if (var >= var_entries_.size()) break; + const int64_t value = solution[var]; if (value == 0) continue; for (const auto& entry : var_entries_[var]) { @@ -116,30 +107,33 @@ void LinearIncrementalEvaluator::ComputeInitialActivities( } } - if (VLOG_IS_ON(2)) { - for (int ct_index = 0; ct_index < enforced_.size(); ++ct_index) { - VLOG(2) << "Linear(" << ct_index - << "): enforced = " << enforced_[ct_index] - << ", activities = " << activities_[ct_index]; + // Reset the num_true_enforcement_literals_. + num_true_enforcement_literals_.assign(num_constraints, 0); + + // Update enforcement literals count. + for (int var = 0; var < num_vars; ++var) { + if (var >= literal_entries_.size()) break; + + const bool literal_is_true = solution[var] != 0; + for (const auto& entry : literal_entries_[var]) { + if (literal_is_true == entry.positive) { + num_true_enforcement_literals_[entry.ct_index]++; + } } } } void LinearIncrementalEvaluator::Update(int var, int64_t old_value, - absl::Span solution) { + int64_t new_value) { + DCHECK_NE(old_value, new_value); if (var < literal_entries_.size()) { + const bool literal_is_true = new_value != 0; for (const LiteralEntry& entry : literal_entries_[var]) { const int ct_index = entry.ct_index; - if ((solution[var] == 0) == entry.positive) { - enforced_[ct_index] = false; + if (literal_is_true == entry.positive) { + num_true_enforcement_literals_[ct_index]++; } else { - enforced_[ct_index] = true; - for (const int lit : enforcement_literals_[ct_index]) { - if (!LiteralValue(lit, solution)) { - enforced_[ct_index] = false; - return; - } - } + num_true_enforcement_literals_[ct_index]--; } } } @@ -147,53 +141,69 @@ void LinearIncrementalEvaluator::Update(int var, int64_t old_value, if (var < var_entries_.size()) { for (const auto& entry : var_entries_[var]) { activities_[entry.ct_index] += - entry.coefficient * (solution[var] - old_value); + entry.coefficient * (new_value - old_value); } } - if (VLOG_IS_ON(2)) { - for (int ct_index = 0; ct_index < enforced_.size(); ++ct_index) { + if (DEBUG_MODE) { + for (int ct_index = 0; ct_index < num_enforcement_literals_.size(); + ++ct_index) { + DCHECK_GE(num_true_enforcement_literals_[ct_index], 0); + DCHECK_LE(num_true_enforcement_literals_[ct_index], + num_enforcement_literals_[ct_index]); } } } int64_t LinearIncrementalEvaluator::Violation(int ct_index) const { - if (!enforced_[ct_index]) { - VLOG(2) << "Linear(" << ct_index << "): violation = 0 as not enforced"; + if (num_true_enforcement_literals_[ct_index] < + num_enforcement_literals_[ct_index]) { return 0; } const int64_t act = activities_[ct_index]; if (act < lower_bounds_[ct_index]) { - VLOG(2) << "Linear(" << ct_index - << "): violation = " << lower_bounds_[ct_index] - act; return lower_bounds_[ct_index] - act; } else if (act > upper_bounds_[ct_index]) { - VLOG(2) << "Linear(" << ct_index - << "): violation = " << act - upper_bounds_[ct_index]; return act - upper_bounds_[ct_index]; } else { - VLOG(2) << "Linear(" << ct_index << "): violation = 0"; return 0; } } -void LinearIncrementalEvaluator::ResetObjectiveBounds(int64_t lb, int64_t ub) { - lower_bounds_[0] = lb; - upper_bounds_[0] = ub; +void LinearIncrementalEvaluator::ResetBounds(int ct_index, int64_t lb, + int64_t ub) { + lower_bounds_[ct_index] = lb; + upper_bounds_[ct_index] = ub; } +// ----- CompiledConstraint ----- + +CompiledConstraint::CompiledConstraint(const ConstraintProto& proto) + : proto_(proto) {} + // ----- CompiledLinMaxConstraint ----- -CompiledLinMaxConstraint::CompiledLinMaxConstraint( - const LinearArgumentProto& proto) - : proto_(proto) {} +// The violation of a lin_max constraint is: +// - the sum(max(0, expr_value - target_value) forall expr) +// - min(target_value - expr_value for all expr) if the above sum is 0 +class CompiledLinMaxConstraint : public CompiledConstraint { + public: + explicit CompiledLinMaxConstraint(const ConstraintProto& proto); + ~CompiledLinMaxConstraint() override = default; -int64_t CompiledLinMaxConstraint::Evaluate(absl::Span solution) { - const int64_t target_value = ExprValue(proto_.target(), solution); + int64_t Violation(absl::Span solution) override; +}; + +CompiledLinMaxConstraint::CompiledLinMaxConstraint(const ConstraintProto& proto) + : CompiledConstraint(proto) {} + +int64_t CompiledLinMaxConstraint::Violation( + absl::Span solution) { + const int64_t target_value = ExprValue(proto().lin_max().target(), solution); int64_t sum_of_excesses = 0; int64_t min_missing_quantities = std::numeric_limits::max(); - for (const LinearExpressionProto& expr : proto_.exprs()) { + for (const LinearExpressionProto& expr : proto().lin_max().exprs()) { const int64_t expr_value = ExprValue(expr, solution); if (expr_value <= target_value) { min_missing_quantities = @@ -209,57 +219,59 @@ int64_t CompiledLinMaxConstraint::Evaluate(absl::Span solution) { } } -void CompiledLinMaxConstraint::CallOnEachVariable( - std::function func) const { - for (const int var : proto_.target().vars()) { - func(var); - } - for (const LinearExpressionProto& expr : proto_.exprs()) { - for (const int var : expr.vars()) { - func(var); - } - } -} +// ----- CompiledIntProdConstraint ----- -// ----- CompiledProductConstraint ----- +// The violation of an int_prod constraint is +// abs(value(target) - prod(value(expr)). +class CompiledIntProdConstraint : public CompiledConstraint { + public: + explicit CompiledIntProdConstraint(const ConstraintProto& proto); + ~CompiledIntProdConstraint() override = default; -CompiledProductConstraint::CompiledProductConstraint( - const LinearArgumentProto& proto) - : proto_(proto) {} + int64_t Violation(absl::Span solution) override; +}; -int64_t CompiledProductConstraint::Evaluate( +CompiledIntProdConstraint::CompiledIntProdConstraint( + const ConstraintProto& proto) + : CompiledConstraint(proto) {} + +int64_t CompiledIntProdConstraint::Violation( absl::Span solution) { - const int64_t target_value = ExprValue(proto_.target(), solution); - CHECK_EQ(proto_.exprs_size(), 2); - const int64_t prod_value = ExprValue(proto_.exprs(0), solution) * - ExprValue(proto_.exprs(1), solution); + const int64_t target_value = ExprValue(proto().int_prod().target(), solution); + CHECK_EQ(proto().int_prod().exprs_size(), 2); + const int64_t prod_value = ExprValue(proto().int_prod().exprs(0), solution) * + ExprValue(proto().int_prod().exprs(1), solution); return std::abs(target_value - prod_value); } -void CompiledProductConstraint::CallOnEachVariable( - std::function func) const { - for (const int var : proto_.target().vars()) { - func(var); - } - for (const LinearExpressionProto& expr : proto_.exprs()) { - for (const int var : expr.vars()) { - func(var); - } - } +// ----- CompiledIntDivConstraint ----- + +// The violation of an int_div constraint is +// abs(value(target) - value(expr0) / value(expr1)). +class CompiledIntDivConstraint : public CompiledConstraint { + public: + explicit CompiledIntDivConstraint(const ConstraintProto& proto); + ~CompiledIntDivConstraint() override = default; + + int64_t Violation(absl::Span solution) override; +}; + +CompiledIntDivConstraint::CompiledIntDivConstraint(const ConstraintProto& proto) + : CompiledConstraint(proto) {} + +int64_t CompiledIntDivConstraint::Violation( + absl::Span solution) { + const int64_t target_value = ExprValue(proto().int_div().target(), solution); + CHECK_EQ(proto().int_div().exprs_size(), 2); + const int64_t div_value = ExprValue(proto().int_div().exprs(0), solution) / + ExprValue(proto().int_div().exprs(1), solution); + return std::abs(target_value - div_value); } // ----- LsEvaluator ----- LsEvaluator::LsEvaluator(const CpModelProto& model) : model_(model) { var_to_constraint_graph_.resize(model_.variables_size()); -} - -void LsEvaluator::UpdateVariableDomains(const CpModelProto& variables_only) { - *model_.mutable_variables() = variables_only.variables(); - CompileModel(); -} - -void LsEvaluator::CompileModel() { CompileConstraintsAndObjective(); BuildVarConstraintGraph(); } @@ -271,15 +283,8 @@ void LsEvaluator::BuildVarConstraintGraph() { } // Build the constraint graph. - const auto collect = [this](int var) { - if (VariableIsFixed(var)) return; - tmp_vars_.insert(var); - }; - for (int ct_index = 0; ct_index < constraints_.size(); ++ct_index) { - tmp_vars_.clear(); - constraints_[ct_index]->CallOnEachVariable(collect); - for (const int var : tmp_vars_) { + for (const int var : UsedVariables(constraints_[ct_index]->proto())) { var_to_constraint_graph_[var].push_back(ct_index); } } @@ -297,11 +302,7 @@ void LsEvaluator::CompileConstraintsAndObjective() { for (int i = 0; i < model_.objective().vars_size(); ++i) { const int var = model_.objective().vars(i); const int64_t coeff = model_.objective().coeffs(i); - if (VariableIsFixed(var)) { - linear_evaluator_.AddOffset(ct_index, VariableValue(var) * coeff); - } else { - linear_evaluator_.AddTerm(ct_index, var, coeff); - } + linear_evaluator_.AddTerm(ct_index, var, coeff); } } @@ -346,27 +347,21 @@ void LsEvaluator::CompileConstraintsAndObjective() { } break; } - case ConstraintProto::ConstraintCase::kBoolXor: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kIntDiv: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kIntMod: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; case ConstraintProto::ConstraintCase::kLinMax: { - CompiledLinMaxConstraint* lin_max = - new CompiledLinMaxConstraint(ct.lin_max()); + CompiledLinMaxConstraint* lin_max = new CompiledLinMaxConstraint(ct); constraints_.emplace_back(lin_max); break; } case ConstraintProto::ConstraintCase::kIntProd: { - CompiledProductConstraint* int_prod = - new CompiledProductConstraint(ct.int_prod()); + CompiledIntProdConstraint* int_prod = new CompiledIntProdConstraint(ct); constraints_.emplace_back(int_prod); break; } + case ConstraintProto::ConstraintCase::kIntDiv: { + CompiledIntDivConstraint* int_div = new CompiledIntDivConstraint(ct); + constraints_.emplace_back(int_div); + break; + } case ConstraintProto::ConstraintCase::kLinear: { CHECK_EQ(ct.linear().domain_size(), 2); const int ct_index = linear_evaluator_.NewConstraint( @@ -377,62 +372,20 @@ void LsEvaluator::CompileConstraintsAndObjective() { for (int i = 0; i < ct.linear().vars_size(); ++i) { const int var = ct.linear().vars(i); const int64_t coeff = ct.linear().coeffs(i); - if (VariableIsFixed(var)) { - linear_evaluator_.AddOffset(ct_index, VariableValue(var) * coeff); - } else { - linear_evaluator_.AddTerm(ct_index, var, coeff); - } + linear_evaluator_.AddTerm(ct_index, var, coeff); } break; } - case ConstraintProto::ConstraintCase::kAllDiff: + default: LOG(FATAL) << "Not implemented" << ct.constraint_case(); break; - case ConstraintProto::ConstraintCase::kDummyConstraint: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kElement: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kCircuit: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kRoutes: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kInverse: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kReservoir: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kTable: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kAutomaton: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kInterval: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kNoOverlap: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kNoOverlap2D: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::kCumulative: - LOG(FATAL) << "Not implemented" << ct.constraint_case(); - break; - case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: - break; } } } void LsEvaluator::SetObjectiveBounds(int64_t lb, int64_t ub) { if (!model_.has_objective()) return; - linear_evaluator_.ResetObjectiveBounds(lb, ub); + linear_evaluator_.ResetBounds(/*ct_index=*/0, lb, ub); } void LsEvaluator::ComputeInitialViolations(absl::Span solution) { @@ -443,7 +396,7 @@ void LsEvaluator::ComputeInitialViolations(absl::Span solution) { // Generic constraints. for (const auto& ct : constraints_) { - const int64_t ct_eval = ct->Evaluate(solution); + const int64_t ct_eval = ct->Violation(solution); ct->clear_violations(); ct->push_violation(ct_eval); } @@ -456,9 +409,10 @@ void LsEvaluator::UpdateVariableValueAndRecomputeViolations(int var, if (old_value == new_value) return; current_solution_[var] = new_value; - linear_evaluator_.Update(var, old_value, current_solution_); + linear_evaluator_.Update(var, old_value, new_value); for (const int ct_index : var_to_constraint_graph_[var]) { - const int64_t ct_eval = constraints_[ct_index]->Evaluate(current_solution_); + const int64_t ct_eval = + constraints_[ct_index]->Violation(current_solution_); constraints_[ct_index]->clear_violations(); constraints_[ct_index]->push_violation(ct_eval); } @@ -468,9 +422,8 @@ int64_t LsEvaluator::SumOfViolations() { int64_t evaluation = 0; // Process the linear part. - for (int lin_index = 0; lin_index < linear_evaluator_.num_constraints(); - ++lin_index) { - evaluation += linear_evaluator_.Violation(lin_index); + for (int i = 0; i < linear_evaluator_.num_constraints(); ++i) { + evaluation += linear_evaluator_.Violation(i); } // Process the generic constraint part. @@ -480,44 +433,5 @@ int64_t LsEvaluator::SumOfViolations() { return evaluation; } -bool LsEvaluator::VariableIsFixed(int ref) const { - const int var = PositiveRef(ref); - const IntegerVariableProto& var_proto = model_.variables(var); - return var_proto.domain_size() == 2 && - var_proto.domain(0) == var_proto.domain(1); -} - -int64_t LsEvaluator::VariableMin(int var) const { - CHECK(RefIsPositive(var)); - const IntegerVariableProto& var_proto = model_.variables(var); - return var_proto.domain(0); -} - -int64_t LsEvaluator::VariableMax(int var) const { - CHECK(RefIsPositive(var)); - const IntegerVariableProto& var_proto = model_.variables(var); - return var_proto.domain(var_proto.domain_size() - 1); -} - -int64_t LsEvaluator::VariableValue(int var) const { - CHECK(VariableIsFixed(var)); - return VariableMin(var); -} - -bool LsEvaluator::LiteralValue(int lit) const { - CHECK(VariableIsFixed(lit)); - return RefIsPositive(lit) == (VariableValue(PositiveRef(lit)) == 1); -} - -bool LsEvaluator::LiteralIsFalse(int lit) const { - if (!VariableIsFixed(lit)) return false; - return RefIsPositive(lit) == (VariableValue(PositiveRef(lit)) == 0); -} - -bool LsEvaluator::LiteralIsTrue(int lit) const { - if (!VariableIsFixed(lit)) return false; - return RefIsPositive(lit) == (VariableValue(PositiveRef(lit)) == 1); -} - } // namespace sat } // namespace operations_research diff --git a/ortools/sat/constraint_violation.h b/ortools/sat/constraint_violation.h index 57491f43a6..3386a35c5a 100644 --- a/ortools/sat/constraint_violation.h +++ b/ortools/sat/constraint_violation.h @@ -15,35 +15,20 @@ #define OR_TOOLS_SAT_CONSTRAINT_VIOLATION_H_ #include -#include -#include #include #include -#include "absl/container/flat_hash_set.h" #include "ortools/sat/cp_model.pb.h" namespace operations_research { namespace sat { -class LsEvaluator; - bool LiteralValue(int lit, absl::Span solution); int64_t ExprValue(const LinearExpressionProto& expr, - absl ::Span solution); + absl::Span solution); class LinearIncrementalEvaluator { public: - struct Entry { - int ct_index; - int64_t coefficient; - }; - - struct LiteralEntry { - int ct_index; - bool positive; - }; - LinearIncrementalEvaluator() = default; // Returns the index of the new constraint. @@ -57,20 +42,30 @@ class LinearIncrementalEvaluator { // Compute activities and query violations. void ComputeInitialActivities(absl::Span solution); - void Update(int var, int64_t old_value, absl::Span solution); + void Update(int var, int64_t old_value, int64_t new_value); int64_t Violation(int ct_index) const; - // Manage the objective. - void ResetObjectiveBounds(int64_t lb, int64_t ub); + // Update constraint bounds. + void ResetBounds(int ct_index, int64_t lb, int64_t ub); // Model getters. int num_constraints() const { return activities_.size(); } private: + // Cell in the sparse matrix. + struct Entry { + int ct_index; + int64_t coefficient; + }; + + // Column-view of the enforcement literals. + struct LiteralEntry { + int ct_index; + bool positive; // bool_var or its negation. + }; + // Model data. - // TODO(user): We should store the constraint proto here instead of the - // enforcement literals. Just a bit problematic with the objective. - std::vector> enforcement_literals_; + std::vector num_enforcement_literals_; std::vector lower_bounds_; std::vector upper_bounds_; @@ -81,20 +76,20 @@ class LinearIncrementalEvaluator { // Dynamic data. std::vector activities_; - std::vector enforced_; + std::vector num_true_enforcement_literals_; }; // View of a generic (non linear) constraint for the LsEvaluator. class CompiledConstraint { public: - explicit CompiledConstraint() = default; + explicit CompiledConstraint(const ConstraintProto& proto); virtual ~CompiledConstraint() = default; - // Evaluation. - virtual int64_t Evaluate(absl::Span solution) = 0; - - // Utilities to compute the var <-> constraint graph. - virtual void CallOnEachVariable(std::function func) const = 0; + // Computes the violation of a constraint. + // + // A violation is a positive integer value. A zero value means the constraint + // is not violated.. + virtual int64_t Violation(absl::Span solution) = 0; // Violations are stored in a stack for each constraint. int64_t current_violation() const { return violations_.back(); } @@ -102,50 +97,19 @@ class CompiledConstraint { void pop_violation() { violations_.pop_back(); } void clear_violations() { violations_.clear(); } + const ConstraintProto& proto() const { return proto_; } + private: + const ConstraintProto& proto_; std::vector violations_; }; -// The violation of a lin_max constraint is: -// - the sum(max(0, expr_value - target_value) forall expr) -// - min(target_value - expr_value for all expr) if the above sum is 0 -class CompiledLinMaxConstraint : public CompiledConstraint { - public: - explicit CompiledLinMaxConstraint(const LinearArgumentProto& proto); - ~CompiledLinMaxConstraint() override = default; - - int64_t Evaluate(absl::Span solution) override; - void CallOnEachVariable(std::function func) const override; - - private: - const LinearArgumentProto& proto_; -}; - -// The violation of an int_prod constraint is -// abs(target_value - prod(expr value)). -class CompiledProductConstraint : public CompiledConstraint { - public: - explicit CompiledProductConstraint(const LinearArgumentProto& proto); - ~CompiledProductConstraint() override = default; - - int64_t Evaluate(absl::Span solution) override; - void CallOnEachVariable(std::function func) const override; - - private: - const LinearArgumentProto& proto_; -}; - // Evaluation container for the local search. class LsEvaluator { public: + // The model must outlive this class. explicit LsEvaluator(const CpModelProto& model); - // Assigns the variable domains from variables_only to the current storage. - void UpdateVariableDomains(const CpModelProto& variables_only); - - // Compiles the current model into a efficient representation. - void CompileModel(); - // Overwrites the bounds of the objective. void SetObjectiveBounds(int64_t lb, int64_t ub); @@ -158,25 +122,15 @@ class LsEvaluator { // Simple summation metric for the constraint and objective violations. int64_t SumOfViolations(); - // Getters for variables using the domains set with UpdateVariableDomains(). - bool VariableIsFixed(int ref) const; - int64_t VariableMin(int var) const; - int64_t VariableMax(int var) const; - int64_t VariableValue(int var) const; - bool LiteralValue(int lit) const; - bool LiteralIsFalse(int lit) const; - bool LiteralIsTrue(int lit) const; - private: void CompileConstraintsAndObjective(); void BuildVarConstraintGraph(); - CpModelProto model_; + const CpModelProto& model_; LinearIncrementalEvaluator linear_evaluator_; std::vector> constraints_; std::vector> var_to_constraint_graph_; std::vector current_solution_; - absl::flat_hash_set tmp_vars_; }; } // namespace sat