[CP-SAT] polish constraint_violation code
This commit is contained in:
@@ -16,15 +16,13 @@
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/sat/cp_model.pb.h"
|
||||
#include "ortools/sat/cp_model_utils.h"
|
||||
#include "ortools/util/saturated_arithmetic.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
@@ -47,18 +45,18 @@ bool LiteralValue(int lit, absl::Span<const int64_t> solution) {
|
||||
|
||||
int LinearIncrementalEvaluator::NewConstraint(int64_t lb, int64_t ub) {
|
||||
const int ct_index = lower_bounds_.size();
|
||||
enforcement_literals_.resize(ct_index + 1);
|
||||
num_enforcement_literals_.push_back(0);
|
||||
lower_bounds_.push_back(lb);
|
||||
upper_bounds_.push_back(ub);
|
||||
offsets_.push_back(0);
|
||||
enforced_.push_back(true);
|
||||
num_true_enforcement_literals_.push_back(0);
|
||||
activities_.push_back(0);
|
||||
return ct_index;
|
||||
}
|
||||
|
||||
void LinearIncrementalEvaluator::AddEnforcementLiteral(int ct_index, int lit) {
|
||||
// Update the row-major storage.
|
||||
enforcement_literals_[ct_index].push_back(lit);
|
||||
num_enforcement_literals_[ct_index]++;
|
||||
|
||||
// Update the column-major storage.
|
||||
const int var = PositiveRef(lit);
|
||||
@@ -92,23 +90,16 @@ void LinearIncrementalEvaluator::AddOffset(int ct_index, int64_t offset) {
|
||||
|
||||
void LinearIncrementalEvaluator::ComputeInitialActivities(
|
||||
absl::Span<const int64_t> solution) {
|
||||
// Process enforcement literals.
|
||||
for (int ct_index = 0; ct_index < enforced_.size(); ++ct_index) {
|
||||
// Resets the activity as the offset.
|
||||
activities_[ct_index] = offsets_[ct_index];
|
||||
const int num_vars = var_entries_.size();
|
||||
const int num_constraints = num_enforcement_literals_.size();
|
||||
|
||||
// Checks if the constraint is enforced.
|
||||
enforced_[ct_index] = true;
|
||||
for (const int lit : enforcement_literals_[ct_index]) {
|
||||
if (!LiteralValue(lit, solution)) {
|
||||
enforced_[ct_index] = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Resets the activity as the offset.
|
||||
activities_ = offsets_;
|
||||
|
||||
// Updates activities from variables and coefficients.
|
||||
for (int var = 0; var < var_entries_.size(); ++var) {
|
||||
for (int var = 0; var < num_vars; ++var) {
|
||||
if (var >= var_entries_.size()) break;
|
||||
|
||||
const int64_t value = solution[var];
|
||||
if (value == 0) continue;
|
||||
for (const auto& entry : var_entries_[var]) {
|
||||
@@ -116,30 +107,33 @@ void LinearIncrementalEvaluator::ComputeInitialActivities(
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
for (int ct_index = 0; ct_index < enforced_.size(); ++ct_index) {
|
||||
VLOG(2) << "Linear(" << ct_index
|
||||
<< "): enforced = " << enforced_[ct_index]
|
||||
<< ", activities = " << activities_[ct_index];
|
||||
// Reset the num_true_enforcement_literals_.
|
||||
num_true_enforcement_literals_.assign(num_constraints, 0);
|
||||
|
||||
// Update enforcement literals count.
|
||||
for (int var = 0; var < num_vars; ++var) {
|
||||
if (var >= literal_entries_.size()) break;
|
||||
|
||||
const bool literal_is_true = solution[var] != 0;
|
||||
for (const auto& entry : literal_entries_[var]) {
|
||||
if (literal_is_true == entry.positive) {
|
||||
num_true_enforcement_literals_[entry.ct_index]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LinearIncrementalEvaluator::Update(int var, int64_t old_value,
|
||||
absl::Span<const int64_t> solution) {
|
||||
int64_t new_value) {
|
||||
DCHECK_NE(old_value, new_value);
|
||||
if (var < literal_entries_.size()) {
|
||||
const bool literal_is_true = new_value != 0;
|
||||
for (const LiteralEntry& entry : literal_entries_[var]) {
|
||||
const int ct_index = entry.ct_index;
|
||||
if ((solution[var] == 0) == entry.positive) {
|
||||
enforced_[ct_index] = false;
|
||||
if (literal_is_true == entry.positive) {
|
||||
num_true_enforcement_literals_[ct_index]++;
|
||||
} else {
|
||||
enforced_[ct_index] = true;
|
||||
for (const int lit : enforcement_literals_[ct_index]) {
|
||||
if (!LiteralValue(lit, solution)) {
|
||||
enforced_[ct_index] = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
num_true_enforcement_literals_[ct_index]--;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -147,53 +141,69 @@ void LinearIncrementalEvaluator::Update(int var, int64_t old_value,
|
||||
if (var < var_entries_.size()) {
|
||||
for (const auto& entry : var_entries_[var]) {
|
||||
activities_[entry.ct_index] +=
|
||||
entry.coefficient * (solution[var] - old_value);
|
||||
entry.coefficient * (new_value - old_value);
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(2)) {
|
||||
for (int ct_index = 0; ct_index < enforced_.size(); ++ct_index) {
|
||||
if (DEBUG_MODE) {
|
||||
for (int ct_index = 0; ct_index < num_enforcement_literals_.size();
|
||||
++ct_index) {
|
||||
DCHECK_GE(num_true_enforcement_literals_[ct_index], 0);
|
||||
DCHECK_LE(num_true_enforcement_literals_[ct_index],
|
||||
num_enforcement_literals_[ct_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t LinearIncrementalEvaluator::Violation(int ct_index) const {
|
||||
if (!enforced_[ct_index]) {
|
||||
VLOG(2) << "Linear(" << ct_index << "): violation = 0 as not enforced";
|
||||
if (num_true_enforcement_literals_[ct_index] <
|
||||
num_enforcement_literals_[ct_index]) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int64_t act = activities_[ct_index];
|
||||
if (act < lower_bounds_[ct_index]) {
|
||||
VLOG(2) << "Linear(" << ct_index
|
||||
<< "): violation = " << lower_bounds_[ct_index] - act;
|
||||
return lower_bounds_[ct_index] - act;
|
||||
} else if (act > upper_bounds_[ct_index]) {
|
||||
VLOG(2) << "Linear(" << ct_index
|
||||
<< "): violation = " << act - upper_bounds_[ct_index];
|
||||
return act - upper_bounds_[ct_index];
|
||||
} else {
|
||||
VLOG(2) << "Linear(" << ct_index << "): violation = 0";
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void LinearIncrementalEvaluator::ResetObjectiveBounds(int64_t lb, int64_t ub) {
|
||||
lower_bounds_[0] = lb;
|
||||
upper_bounds_[0] = ub;
|
||||
void LinearIncrementalEvaluator::ResetBounds(int ct_index, int64_t lb,
|
||||
int64_t ub) {
|
||||
lower_bounds_[ct_index] = lb;
|
||||
upper_bounds_[ct_index] = ub;
|
||||
}
|
||||
|
||||
// ----- CompiledConstraint -----
|
||||
|
||||
CompiledConstraint::CompiledConstraint(const ConstraintProto& proto)
|
||||
: proto_(proto) {}
|
||||
|
||||
// ----- CompiledLinMaxConstraint -----
|
||||
|
||||
CompiledLinMaxConstraint::CompiledLinMaxConstraint(
|
||||
const LinearArgumentProto& proto)
|
||||
: proto_(proto) {}
|
||||
// The violation of a lin_max constraint is:
|
||||
// - the sum(max(0, expr_value - target_value) forall expr)
|
||||
// - min(target_value - expr_value for all expr) if the above sum is 0
|
||||
class CompiledLinMaxConstraint : public CompiledConstraint {
|
||||
public:
|
||||
explicit CompiledLinMaxConstraint(const ConstraintProto& proto);
|
||||
~CompiledLinMaxConstraint() override = default;
|
||||
|
||||
int64_t CompiledLinMaxConstraint::Evaluate(absl::Span<const int64_t> solution) {
|
||||
const int64_t target_value = ExprValue(proto_.target(), solution);
|
||||
int64_t Violation(absl::Span<const int64_t> solution) override;
|
||||
};
|
||||
|
||||
CompiledLinMaxConstraint::CompiledLinMaxConstraint(const ConstraintProto& proto)
|
||||
: CompiledConstraint(proto) {}
|
||||
|
||||
int64_t CompiledLinMaxConstraint::Violation(
|
||||
absl::Span<const int64_t> solution) {
|
||||
const int64_t target_value = ExprValue(proto().lin_max().target(), solution);
|
||||
int64_t sum_of_excesses = 0;
|
||||
int64_t min_missing_quantities = std::numeric_limits<int64_t>::max();
|
||||
for (const LinearExpressionProto& expr : proto_.exprs()) {
|
||||
for (const LinearExpressionProto& expr : proto().lin_max().exprs()) {
|
||||
const int64_t expr_value = ExprValue(expr, solution);
|
||||
if (expr_value <= target_value) {
|
||||
min_missing_quantities =
|
||||
@@ -209,57 +219,59 @@ int64_t CompiledLinMaxConstraint::Evaluate(absl::Span<const int64_t> solution) {
|
||||
}
|
||||
}
|
||||
|
||||
void CompiledLinMaxConstraint::CallOnEachVariable(
|
||||
std::function<void(int)> func) const {
|
||||
for (const int var : proto_.target().vars()) {
|
||||
func(var);
|
||||
}
|
||||
for (const LinearExpressionProto& expr : proto_.exprs()) {
|
||||
for (const int var : expr.vars()) {
|
||||
func(var);
|
||||
}
|
||||
}
|
||||
}
|
||||
// ----- CompiledIntProdConstraint -----
|
||||
|
||||
// ----- CompiledProductConstraint -----
|
||||
// The violation of an int_prod constraint is
|
||||
// abs(value(target) - prod(value(expr)).
|
||||
class CompiledIntProdConstraint : public CompiledConstraint {
|
||||
public:
|
||||
explicit CompiledIntProdConstraint(const ConstraintProto& proto);
|
||||
~CompiledIntProdConstraint() override = default;
|
||||
|
||||
CompiledProductConstraint::CompiledProductConstraint(
|
||||
const LinearArgumentProto& proto)
|
||||
: proto_(proto) {}
|
||||
int64_t Violation(absl::Span<const int64_t> solution) override;
|
||||
};
|
||||
|
||||
int64_t CompiledProductConstraint::Evaluate(
|
||||
CompiledIntProdConstraint::CompiledIntProdConstraint(
|
||||
const ConstraintProto& proto)
|
||||
: CompiledConstraint(proto) {}
|
||||
|
||||
int64_t CompiledIntProdConstraint::Violation(
|
||||
absl::Span<const int64_t> solution) {
|
||||
const int64_t target_value = ExprValue(proto_.target(), solution);
|
||||
CHECK_EQ(proto_.exprs_size(), 2);
|
||||
const int64_t prod_value = ExprValue(proto_.exprs(0), solution) *
|
||||
ExprValue(proto_.exprs(1), solution);
|
||||
const int64_t target_value = ExprValue(proto().int_prod().target(), solution);
|
||||
CHECK_EQ(proto().int_prod().exprs_size(), 2);
|
||||
const int64_t prod_value = ExprValue(proto().int_prod().exprs(0), solution) *
|
||||
ExprValue(proto().int_prod().exprs(1), solution);
|
||||
return std::abs(target_value - prod_value);
|
||||
}
|
||||
|
||||
void CompiledProductConstraint::CallOnEachVariable(
|
||||
std::function<void(int)> func) const {
|
||||
for (const int var : proto_.target().vars()) {
|
||||
func(var);
|
||||
}
|
||||
for (const LinearExpressionProto& expr : proto_.exprs()) {
|
||||
for (const int var : expr.vars()) {
|
||||
func(var);
|
||||
}
|
||||
}
|
||||
// ----- CompiledIntDivConstraint -----
|
||||
|
||||
// The violation of an int_div constraint is
|
||||
// abs(value(target) - value(expr0) / value(expr1)).
|
||||
class CompiledIntDivConstraint : public CompiledConstraint {
|
||||
public:
|
||||
explicit CompiledIntDivConstraint(const ConstraintProto& proto);
|
||||
~CompiledIntDivConstraint() override = default;
|
||||
|
||||
int64_t Violation(absl::Span<const int64_t> solution) override;
|
||||
};
|
||||
|
||||
CompiledIntDivConstraint::CompiledIntDivConstraint(const ConstraintProto& proto)
|
||||
: CompiledConstraint(proto) {}
|
||||
|
||||
int64_t CompiledIntDivConstraint::Violation(
|
||||
absl::Span<const int64_t> solution) {
|
||||
const int64_t target_value = ExprValue(proto().int_div().target(), solution);
|
||||
CHECK_EQ(proto().int_div().exprs_size(), 2);
|
||||
const int64_t div_value = ExprValue(proto().int_div().exprs(0), solution) /
|
||||
ExprValue(proto().int_div().exprs(1), solution);
|
||||
return std::abs(target_value - div_value);
|
||||
}
|
||||
|
||||
// ----- LsEvaluator -----
|
||||
|
||||
LsEvaluator::LsEvaluator(const CpModelProto& model) : model_(model) {
|
||||
var_to_constraint_graph_.resize(model_.variables_size());
|
||||
}
|
||||
|
||||
void LsEvaluator::UpdateVariableDomains(const CpModelProto& variables_only) {
|
||||
*model_.mutable_variables() = variables_only.variables();
|
||||
CompileModel();
|
||||
}
|
||||
|
||||
void LsEvaluator::CompileModel() {
|
||||
CompileConstraintsAndObjective();
|
||||
BuildVarConstraintGraph();
|
||||
}
|
||||
@@ -271,15 +283,8 @@ void LsEvaluator::BuildVarConstraintGraph() {
|
||||
}
|
||||
|
||||
// Build the constraint graph.
|
||||
const auto collect = [this](int var) {
|
||||
if (VariableIsFixed(var)) return;
|
||||
tmp_vars_.insert(var);
|
||||
};
|
||||
|
||||
for (int ct_index = 0; ct_index < constraints_.size(); ++ct_index) {
|
||||
tmp_vars_.clear();
|
||||
constraints_[ct_index]->CallOnEachVariable(collect);
|
||||
for (const int var : tmp_vars_) {
|
||||
for (const int var : UsedVariables(constraints_[ct_index]->proto())) {
|
||||
var_to_constraint_graph_[var].push_back(ct_index);
|
||||
}
|
||||
}
|
||||
@@ -297,11 +302,7 @@ void LsEvaluator::CompileConstraintsAndObjective() {
|
||||
for (int i = 0; i < model_.objective().vars_size(); ++i) {
|
||||
const int var = model_.objective().vars(i);
|
||||
const int64_t coeff = model_.objective().coeffs(i);
|
||||
if (VariableIsFixed(var)) {
|
||||
linear_evaluator_.AddOffset(ct_index, VariableValue(var) * coeff);
|
||||
} else {
|
||||
linear_evaluator_.AddTerm(ct_index, var, coeff);
|
||||
}
|
||||
linear_evaluator_.AddTerm(ct_index, var, coeff);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,27 +347,21 @@ void LsEvaluator::CompileConstraintsAndObjective() {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ConstraintProto::ConstraintCase::kBoolXor:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kIntDiv:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kIntMod:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kLinMax: {
|
||||
CompiledLinMaxConstraint* lin_max =
|
||||
new CompiledLinMaxConstraint(ct.lin_max());
|
||||
CompiledLinMaxConstraint* lin_max = new CompiledLinMaxConstraint(ct);
|
||||
constraints_.emplace_back(lin_max);
|
||||
break;
|
||||
}
|
||||
case ConstraintProto::ConstraintCase::kIntProd: {
|
||||
CompiledProductConstraint* int_prod =
|
||||
new CompiledProductConstraint(ct.int_prod());
|
||||
CompiledIntProdConstraint* int_prod = new CompiledIntProdConstraint(ct);
|
||||
constraints_.emplace_back(int_prod);
|
||||
break;
|
||||
}
|
||||
case ConstraintProto::ConstraintCase::kIntDiv: {
|
||||
CompiledIntDivConstraint* int_div = new CompiledIntDivConstraint(ct);
|
||||
constraints_.emplace_back(int_div);
|
||||
break;
|
||||
}
|
||||
case ConstraintProto::ConstraintCase::kLinear: {
|
||||
CHECK_EQ(ct.linear().domain_size(), 2);
|
||||
const int ct_index = linear_evaluator_.NewConstraint(
|
||||
@@ -377,62 +372,20 @@ void LsEvaluator::CompileConstraintsAndObjective() {
|
||||
for (int i = 0; i < ct.linear().vars_size(); ++i) {
|
||||
const int var = ct.linear().vars(i);
|
||||
const int64_t coeff = ct.linear().coeffs(i);
|
||||
if (VariableIsFixed(var)) {
|
||||
linear_evaluator_.AddOffset(ct_index, VariableValue(var) * coeff);
|
||||
} else {
|
||||
linear_evaluator_.AddTerm(ct_index, var, coeff);
|
||||
}
|
||||
linear_evaluator_.AddTerm(ct_index, var, coeff);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case ConstraintProto::ConstraintCase::kAllDiff:
|
||||
default:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kDummyConstraint:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kElement:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kCircuit:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kRoutes:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kInverse:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kReservoir:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kTable:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kAutomaton:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kInterval:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kNoOverlap:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kNoOverlap2D:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::kCumulative:
|
||||
LOG(FATAL) << "Not implemented" << ct.constraint_case();
|
||||
break;
|
||||
case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LsEvaluator::SetObjectiveBounds(int64_t lb, int64_t ub) {
|
||||
if (!model_.has_objective()) return;
|
||||
linear_evaluator_.ResetObjectiveBounds(lb, ub);
|
||||
linear_evaluator_.ResetBounds(/*ct_index=*/0, lb, ub);
|
||||
}
|
||||
|
||||
void LsEvaluator::ComputeInitialViolations(absl::Span<const int64_t> solution) {
|
||||
@@ -443,7 +396,7 @@ void LsEvaluator::ComputeInitialViolations(absl::Span<const int64_t> solution) {
|
||||
|
||||
// Generic constraints.
|
||||
for (const auto& ct : constraints_) {
|
||||
const int64_t ct_eval = ct->Evaluate(solution);
|
||||
const int64_t ct_eval = ct->Violation(solution);
|
||||
ct->clear_violations();
|
||||
ct->push_violation(ct_eval);
|
||||
}
|
||||
@@ -456,9 +409,10 @@ void LsEvaluator::UpdateVariableValueAndRecomputeViolations(int var,
|
||||
if (old_value == new_value) return;
|
||||
|
||||
current_solution_[var] = new_value;
|
||||
linear_evaluator_.Update(var, old_value, current_solution_);
|
||||
linear_evaluator_.Update(var, old_value, new_value);
|
||||
for (const int ct_index : var_to_constraint_graph_[var]) {
|
||||
const int64_t ct_eval = constraints_[ct_index]->Evaluate(current_solution_);
|
||||
const int64_t ct_eval =
|
||||
constraints_[ct_index]->Violation(current_solution_);
|
||||
constraints_[ct_index]->clear_violations();
|
||||
constraints_[ct_index]->push_violation(ct_eval);
|
||||
}
|
||||
@@ -468,9 +422,8 @@ int64_t LsEvaluator::SumOfViolations() {
|
||||
int64_t evaluation = 0;
|
||||
|
||||
// Process the linear part.
|
||||
for (int lin_index = 0; lin_index < linear_evaluator_.num_constraints();
|
||||
++lin_index) {
|
||||
evaluation += linear_evaluator_.Violation(lin_index);
|
||||
for (int i = 0; i < linear_evaluator_.num_constraints(); ++i) {
|
||||
evaluation += linear_evaluator_.Violation(i);
|
||||
}
|
||||
|
||||
// Process the generic constraint part.
|
||||
@@ -480,44 +433,5 @@ int64_t LsEvaluator::SumOfViolations() {
|
||||
return evaluation;
|
||||
}
|
||||
|
||||
bool LsEvaluator::VariableIsFixed(int ref) const {
|
||||
const int var = PositiveRef(ref);
|
||||
const IntegerVariableProto& var_proto = model_.variables(var);
|
||||
return var_proto.domain_size() == 2 &&
|
||||
var_proto.domain(0) == var_proto.domain(1);
|
||||
}
|
||||
|
||||
int64_t LsEvaluator::VariableMin(int var) const {
|
||||
CHECK(RefIsPositive(var));
|
||||
const IntegerVariableProto& var_proto = model_.variables(var);
|
||||
return var_proto.domain(0);
|
||||
}
|
||||
|
||||
int64_t LsEvaluator::VariableMax(int var) const {
|
||||
CHECK(RefIsPositive(var));
|
||||
const IntegerVariableProto& var_proto = model_.variables(var);
|
||||
return var_proto.domain(var_proto.domain_size() - 1);
|
||||
}
|
||||
|
||||
int64_t LsEvaluator::VariableValue(int var) const {
|
||||
CHECK(VariableIsFixed(var));
|
||||
return VariableMin(var);
|
||||
}
|
||||
|
||||
bool LsEvaluator::LiteralValue(int lit) const {
|
||||
CHECK(VariableIsFixed(lit));
|
||||
return RefIsPositive(lit) == (VariableValue(PositiveRef(lit)) == 1);
|
||||
}
|
||||
|
||||
bool LsEvaluator::LiteralIsFalse(int lit) const {
|
||||
if (!VariableIsFixed(lit)) return false;
|
||||
return RefIsPositive(lit) == (VariableValue(PositiveRef(lit)) == 0);
|
||||
}
|
||||
|
||||
bool LsEvaluator::LiteralIsTrue(int lit) const {
|
||||
if (!VariableIsFixed(lit)) return false;
|
||||
return RefIsPositive(lit) == (VariableValue(PositiveRef(lit)) == 1);
|
||||
}
|
||||
|
||||
} // namespace sat
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -15,35 +15,20 @@
|
||||
#define OR_TOOLS_SAT_CONSTRAINT_VIOLATION_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "ortools/sat/cp_model.pb.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace sat {
|
||||
|
||||
class LsEvaluator;
|
||||
|
||||
bool LiteralValue(int lit, absl::Span<const int64_t> solution);
|
||||
int64_t ExprValue(const LinearExpressionProto& expr,
|
||||
absl ::Span<const int64_t> solution);
|
||||
absl::Span<const int64_t> solution);
|
||||
|
||||
class LinearIncrementalEvaluator {
|
||||
public:
|
||||
struct Entry {
|
||||
int ct_index;
|
||||
int64_t coefficient;
|
||||
};
|
||||
|
||||
struct LiteralEntry {
|
||||
int ct_index;
|
||||
bool positive;
|
||||
};
|
||||
|
||||
LinearIncrementalEvaluator() = default;
|
||||
|
||||
// Returns the index of the new constraint.
|
||||
@@ -57,20 +42,30 @@ class LinearIncrementalEvaluator {
|
||||
|
||||
// Compute activities and query violations.
|
||||
void ComputeInitialActivities(absl::Span<const int64_t> solution);
|
||||
void Update(int var, int64_t old_value, absl::Span<const int64_t> solution);
|
||||
void Update(int var, int64_t old_value, int64_t new_value);
|
||||
int64_t Violation(int ct_index) const;
|
||||
|
||||
// Manage the objective.
|
||||
void ResetObjectiveBounds(int64_t lb, int64_t ub);
|
||||
// Update constraint bounds.
|
||||
void ResetBounds(int ct_index, int64_t lb, int64_t ub);
|
||||
|
||||
// Model getters.
|
||||
int num_constraints() const { return activities_.size(); }
|
||||
|
||||
private:
|
||||
// Cell in the sparse matrix.
|
||||
struct Entry {
|
||||
int ct_index;
|
||||
int64_t coefficient;
|
||||
};
|
||||
|
||||
// Column-view of the enforcement literals.
|
||||
struct LiteralEntry {
|
||||
int ct_index;
|
||||
bool positive; // bool_var or its negation.
|
||||
};
|
||||
|
||||
// Model data.
|
||||
// TODO(user): We should store the constraint proto here instead of the
|
||||
// enforcement literals. Just a bit problematic with the objective.
|
||||
std::vector<std::vector<int>> enforcement_literals_;
|
||||
std::vector<int> num_enforcement_literals_;
|
||||
std::vector<int64_t> lower_bounds_;
|
||||
std::vector<int64_t> upper_bounds_;
|
||||
|
||||
@@ -81,20 +76,20 @@ class LinearIncrementalEvaluator {
|
||||
|
||||
// Dynamic data.
|
||||
std::vector<int64_t> activities_;
|
||||
std::vector<bool> enforced_;
|
||||
std::vector<int> num_true_enforcement_literals_;
|
||||
};
|
||||
|
||||
// View of a generic (non linear) constraint for the LsEvaluator.
|
||||
class CompiledConstraint {
|
||||
public:
|
||||
explicit CompiledConstraint() = default;
|
||||
explicit CompiledConstraint(const ConstraintProto& proto);
|
||||
virtual ~CompiledConstraint() = default;
|
||||
|
||||
// Evaluation.
|
||||
virtual int64_t Evaluate(absl::Span<const int64_t> solution) = 0;
|
||||
|
||||
// Utilities to compute the var <-> constraint graph.
|
||||
virtual void CallOnEachVariable(std::function<void(int)> func) const = 0;
|
||||
// Computes the violation of a constraint.
|
||||
//
|
||||
// A violation is a positive integer value. A zero value means the constraint
|
||||
// is not violated..
|
||||
virtual int64_t Violation(absl::Span<const int64_t> solution) = 0;
|
||||
|
||||
// Violations are stored in a stack for each constraint.
|
||||
int64_t current_violation() const { return violations_.back(); }
|
||||
@@ -102,50 +97,19 @@ class CompiledConstraint {
|
||||
void pop_violation() { violations_.pop_back(); }
|
||||
void clear_violations() { violations_.clear(); }
|
||||
|
||||
const ConstraintProto& proto() const { return proto_; }
|
||||
|
||||
private:
|
||||
const ConstraintProto& proto_;
|
||||
std::vector<int64_t> violations_;
|
||||
};
|
||||
|
||||
// The violation of a lin_max constraint is:
|
||||
// - the sum(max(0, expr_value - target_value) forall expr)
|
||||
// - min(target_value - expr_value for all expr) if the above sum is 0
|
||||
class CompiledLinMaxConstraint : public CompiledConstraint {
|
||||
public:
|
||||
explicit CompiledLinMaxConstraint(const LinearArgumentProto& proto);
|
||||
~CompiledLinMaxConstraint() override = default;
|
||||
|
||||
int64_t Evaluate(absl::Span<const int64_t> solution) override;
|
||||
void CallOnEachVariable(std::function<void(int)> func) const override;
|
||||
|
||||
private:
|
||||
const LinearArgumentProto& proto_;
|
||||
};
|
||||
|
||||
// The violation of an int_prod constraint is
|
||||
// abs(target_value - prod(expr value)).
|
||||
class CompiledProductConstraint : public CompiledConstraint {
|
||||
public:
|
||||
explicit CompiledProductConstraint(const LinearArgumentProto& proto);
|
||||
~CompiledProductConstraint() override = default;
|
||||
|
||||
int64_t Evaluate(absl::Span<const int64_t> solution) override;
|
||||
void CallOnEachVariable(std::function<void(int)> func) const override;
|
||||
|
||||
private:
|
||||
const LinearArgumentProto& proto_;
|
||||
};
|
||||
|
||||
// Evaluation container for the local search.
|
||||
class LsEvaluator {
|
||||
public:
|
||||
// The model must outlive this class.
|
||||
explicit LsEvaluator(const CpModelProto& model);
|
||||
|
||||
// Assigns the variable domains from variables_only to the current storage.
|
||||
void UpdateVariableDomains(const CpModelProto& variables_only);
|
||||
|
||||
// Compiles the current model into a efficient representation.
|
||||
void CompileModel();
|
||||
|
||||
// Overwrites the bounds of the objective.
|
||||
void SetObjectiveBounds(int64_t lb, int64_t ub);
|
||||
|
||||
@@ -158,25 +122,15 @@ class LsEvaluator {
|
||||
// Simple summation metric for the constraint and objective violations.
|
||||
int64_t SumOfViolations();
|
||||
|
||||
// Getters for variables using the domains set with UpdateVariableDomains().
|
||||
bool VariableIsFixed(int ref) const;
|
||||
int64_t VariableMin(int var) const;
|
||||
int64_t VariableMax(int var) const;
|
||||
int64_t VariableValue(int var) const;
|
||||
bool LiteralValue(int lit) const;
|
||||
bool LiteralIsFalse(int lit) const;
|
||||
bool LiteralIsTrue(int lit) const;
|
||||
|
||||
private:
|
||||
void CompileConstraintsAndObjective();
|
||||
void BuildVarConstraintGraph();
|
||||
|
||||
CpModelProto model_;
|
||||
const CpModelProto& model_;
|
||||
LinearIncrementalEvaluator linear_evaluator_;
|
||||
std::vector<std::unique_ptr<CompiledConstraint>> constraints_;
|
||||
std::vector<std::vector<int>> var_to_constraint_graph_;
|
||||
std::vector<int64_t> current_solution_;
|
||||
absl::flat_hash_set<int> tmp_vars_;
|
||||
};
|
||||
|
||||
} // namespace sat
|
||||
|
||||
Reference in New Issue
Block a user