Files
ortools-clone/ortools/sat/solution_crush.cc

745 lines
26 KiB
C++

// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/sat/solution_crush.h"
#include <algorithm>
#include <cstdint>
#include <memory>
#include <optional>
#ifdef CHECK_CRUSH
#include <sstream>
#include <string>
#endif
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/types/span.h"
#include "ortools/algorithms/sparse_permutation.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/cp_model_utils.h"
#include "ortools/sat/diffn_util.h"
#include "ortools/sat/sat_parameters.pb.h"
#include "ortools/sat/symmetry_util.h"
#include "ortools/sat/util.h"
#include "ortools/util/sorted_interval_list.h"
namespace operations_research {
namespace sat {
// Loads `solution` (a map from variable index to value) as the assignment
// tracked by this object, sized for `num_vars` variables. Variables absent from
// the map are left without a value. CHECK-fails if a solution was already
// loaded or if the internal storage is not empty.
void SolutionCrush::LoadSolution(
    int num_vars, const absl::flat_hash_map<int, int64_t>& solution) {
  CHECK(!solution_is_loaded_);
  CHECK(var_has_value_.empty());
  CHECK(var_values_.empty());
  var_has_value_.resize(num_vars, false);
  var_values_.resize(num_vars, 0);
  solution_is_loaded_ = true;
  for (const auto& entry : solution) {
    const int var = entry.first;
    var_has_value_[var] = true;
    var_values_[var] = entry.second;
  }
}
// Resizes the per-variable storage to `new_size` entries. Entries added by the
// resize start without a value. Does nothing if no solution is loaded.
void SolutionCrush::Resize(int new_size) {
  if (solution_is_loaded_) {
    var_has_value_.resize(new_size, false);
    var_values_.resize(new_size, 0);
  }
}
// Sets `literal` to the truth value of "`var` == `value`", but only when the
// literal's variable has no value yet and `var` has one. `var` must be a
// positive reference.
void SolutionCrush::MaybeSetLiteralToValueEncoding(int literal, int var,
                                                   int64_t value) {
  DCHECK(RefIsPositive(var));
  if (!solution_is_loaded_) return;
  // Never overwrite a literal that is already assigned.
  if (HasValue(PositiveRef(literal))) return;
  if (!HasValue(var)) return;
  SetLiteralValue(literal, GetVarValue(var) == value);
}
// Sets `new_var` to `offset + sum(coeff * value(var))` over the (var, coeff)
// terms of `linear`. Nothing is written if any term's variable has no value.
void SolutionCrush::SetVarToLinearExpression(
    int new_var, absl::Span<const std::pair<int, int64_t>> linear,
    int64_t offset) {
  if (!solution_is_loaded_) return;
  int64_t result = offset;
  for (const std::pair<int, int64_t>& term : linear) {
    if (!HasValue(term.first)) return;
    result += term.second * GetVarValue(term.first);
  }
  SetVarValue(new_var, result);
}
// Same as the (var, coeff) pair overload, with the terms given as two parallel
// spans `vars` and `coeffs` (which must have the same size). Nothing is written
// if any variable of the expression has no value.
void SolutionCrush::SetVarToLinearExpression(int new_var,
                                             absl::Span<const int> vars,
                                             absl::Span<const int64_t> coeffs,
                                             int64_t offset) {
  DCHECK_EQ(vars.size(), coeffs.size());
  if (!solution_is_loaded_) return;
  int64_t result = offset;
  for (int term = 0; term < vars.size(); ++term) {
    if (!HasValue(vars[term])) return;
    result += coeffs[term] * GetVarValue(vars[term]);
  }
  SetVarValue(new_var, result);
}
// Sets `new_var` to the value of the disjunction of the literals in `clause`:
// 1 as soon as a true literal is found, 0 if all literals are false. The
// literals are scanned in order; if a literal without a value is reached before
// any true literal, `new_var` is left untouched.
void SolutionCrush::SetVarToClause(int new_var, absl::Span<const int> clause) {
  if (!solution_is_loaded_) return;
  int clause_value = 0;
  for (const int literal : clause) {
    const int var = PositiveRef(literal);
    if (!HasValue(var)) return;
    // Value of `var` for which `literal` is true.
    const int satisfying_value = RefIsPositive(literal) ? 1 : 0;
    if (GetVarValue(var) == satisfying_value) {
      clause_value = 1;
      break;
    }
  }
  SetVarValue(new_var, clause_value);
}
void SolutionCrush::SetVarToConjunction(int new_var,
absl::Span<const int> conjunction) {
if (!solution_is_loaded_) return;
int new_value = 1;
bool all_have_value = true;
for (const int literal : conjunction) {
const int var = PositiveRef(literal);
if (!HasValue(var)) {
all_have_value = false;
break;
}
if (GetVarValue(var) == (RefIsPositive(literal) ? 0 : 1)) {
new_value = 0;
break;
}
}
// Leave the `new_var` unassigned if any literal is unassigned.
if (all_have_value) {
SetVarValue(new_var, new_value);
}
}
// Sets `new_var` to `value` if the constraint "sum(coeff * var) in `domain`"
// is violated by the current values. Nothing is written if the constraint
// holds, or if one of its variables has no value.
void SolutionCrush::SetVarToValueIfLinearConstraintViolated(
    int new_var, int64_t value,
    absl::Span<const std::pair<int, int64_t>> linear, const Domain& domain) {
  if (!solution_is_loaded_) return;
  int64_t activity = 0;
  for (const std::pair<int, int64_t>& term : linear) {
    if (!HasValue(term.first)) return;
    activity += GetVarValue(term.first) * term.second;
  }
  if (domain.Contains(activity)) return;
  SetVarValue(new_var, value);
}
// Literal version of SetVarToValueIfLinearConstraintViolated(): gives `literal`
// the truth value `value` if the linear constraint is violated. For a negated
// literal, the opposite value is written to the underlying variable.
void SolutionCrush::SetLiteralToValueIfLinearConstraintViolated(
    int literal, bool value, absl::Span<const std::pair<int, int64_t>> linear,
    const Domain& domain) {
  // `value` when the literal is positive, `!value` otherwise.
  const bool var_value = (value == RefIsPositive(literal));
  SetVarToValueIfLinearConstraintViolated(PositiveRef(literal), var_value,
                                          linear, domain);
}
// Sets `var` to `value` if `condition_lit` is true. This is expressed as the
// violation of the one-term constraint "condition variable == v", where v is
// the variable value that makes `condition_lit` false (0 for a positive
// literal, 1 for a negated one).
void SolutionCrush::SetVarToValueIf(int var, int64_t value, int condition_lit) {
  const int condition_var = PositiveRef(condition_lit);
  const int var_value_if_lit_false = RefIsPositive(condition_lit) ? 0 : 1;
  SetVarToValueIfLinearConstraintViolated(var, value, {{condition_var, 1}},
                                          Domain(var_value_if_lit_false));
}
// Sets `var` to the value of `expr` if `condition_lit` has a value and is true.
// Nothing is written if the condition is false or unassigned, or if `expr`
// cannot be evaluated (GetExpressionValue() returns no value).
void SolutionCrush::SetVarToLinearExpressionIf(
    int var, const LinearExpressionProto& expr, int condition_lit) {
  if (!solution_is_loaded_) return;
  const bool condition_holds =
      HasValue(PositiveRef(condition_lit)) && GetLiteralValue(condition_lit);
  if (!condition_holds) return;
  const std::optional<int64_t> expr_value = GetExpressionValue(expr);
  if (!expr_value.has_value()) return;
  SetVarValue(var, *expr_value);
}
// Gives `literal` the truth value `value` if `condition_lit` is true. As in
// SetVarToValueIf(), the condition is encoded as the violation of "condition
// variable == v", with v the variable value making `condition_lit` false.
void SolutionCrush::SetLiteralToValueIf(int literal, bool value,
                                        int condition_lit) {
  const int condition_var = PositiveRef(condition_lit);
  const int var_value_if_lit_false = RefIsPositive(condition_lit) ? 0 : 1;
  SetLiteralToValueIfLinearConstraintViolated(
      literal, value, {{condition_var, 1}}, Domain(var_value_if_lit_false));
}
// Sets `var` to `value_if_true` if all the literals in `condition_lits` are
// true, and to `value_if_false` as soon as a false literal is found. The
// literals are scanned in order; if a literal without a value is reached before
// any false literal, `var` is left untouched.
void SolutionCrush::SetVarToConditionalValue(
    int var, absl::Span<const int> condition_lits, int64_t value_if_true,
    int64_t value_if_false) {
  if (!solution_is_loaded_) return;
  int64_t selected_value = value_if_true;
  for (const int condition_lit : condition_lits) {
    if (!HasValue(PositiveRef(condition_lit))) return;
    if (GetLiteralValue(condition_lit)) continue;
    selected_value = value_if_false;
    break;
  }
  SetVarValue(var, selected_value);
}
// Makes `lit1` and `lit2` take the same truth value. If `lit2` has a value it
// is copied to `lit1` (overwriting any value of `lit1`); otherwise, if `lit1`
// has a value it is copied to `lit2`. Nothing is done if neither has a value.
void SolutionCrush::MakeLiteralsEqual(int lit1, int lit2) {
  if (!solution_is_loaded_) return;
  if (HasValue(PositiveRef(lit2))) {
    SetLiteralValue(lit1, GetLiteralValue(lit2));
    return;
  }
  if (HasValue(PositiveRef(lit1))) {
    SetLiteralValue(lit2, GetLiteralValue(lit1));
  }
}
// Moves the value of `var` into `domain`: an existing value is replaced by
// `domain.ClosestValue()` of it, and a variable without a value is assigned
// only when `domain` is fixed (to that fixed value).
void SolutionCrush::SetOrUpdateVarToDomain(int var, const Domain& domain) {
  if (!solution_is_loaded_) return;
  if (HasValue(var)) {
    const int64_t current_value = GetVarValue(var);
    SetVarValue(var, domain.ClosestValue(current_value));
    return;
  }
  if (domain.IsFixed()) {
    SetVarValue(var, domain.FixedValue());
  }
}
void SolutionCrush::UpdateLiteralsToFalseIfDifferent(int lit1, int lit2) {
// Set lit1 and lit2 to false if "lit1 - lit2 == 0" is violated.
const int sign1 = RefIsPositive(lit1) ? 1 : -1;
const int sign2 = RefIsPositive(lit2) ? 1 : -1;
const std::vector<std::pair<int, int64_t>> linear = {
{PositiveRef(lit1), sign1}, {PositiveRef(lit2), -sign2}};
const Domain domain = Domain((sign1 == 1 ? 0 : -1) - (sign2 == 1 ? 0 : -1));
SetLiteralToValueIfLinearConstraintViolated(lit1, false, linear, domain);
SetLiteralToValueIfLinearConstraintViolated(lit2, false, linear, domain);
}
// If `lit` is true and `dominating_lit` is false, swaps their roles by setting
// `lit` to false and `dominating_lit` to true. Nothing is done unless both
// literals have a value.
void SolutionCrush::UpdateLiteralsWithDominance(int lit, int dominating_lit) {
  if (!solution_is_loaded_) return;
  const bool both_have_value =
      HasValue(PositiveRef(lit)) && HasValue(PositiveRef(dominating_lit));
  if (!both_have_value) return;
  if (!GetLiteralValue(lit) || GetLiteralValue(dominating_lit)) return;
  SetLiteralValue(lit, false);
  SetLiteralValue(dominating_lit, true);
}
void SolutionCrush::MaybeUpdateVarWithSymmetriesToValue(
int var, bool value,
absl::Span<const std::unique_ptr<SparsePermutation>> generators) {
if (!solution_is_loaded_) return;
if (!HasValue(var)) return;
if (GetVarValue(var) == static_cast<int64_t>(value)) return;
std::vector<int> schrier_vector;
std::vector<int> orbit;
GetSchreierVectorAndOrbit(var, generators, &schrier_vector, &orbit);
bool found_target = false;
int target_var;
for (int v : orbit) {
if (HasValue(v) && GetVarValue(v) == static_cast<int64_t>(value)) {
found_target = true;
target_var = v;
break;
}
}
if (!found_target) {
VLOG(1) << "Couldn't transform solution properly";
return;
}
const std::vector<int> generator_idx =
TracePoint(target_var, schrier_vector, generators);
for (const int i : generator_idx) {
PermuteVariables(*generators[i]);
}
DCHECK(HasValue(var));
DCHECK_EQ(GetVarValue(var), value);
}
// In row `row` of `orbitope`, finds the unique column whose literal has the
// truth value `value`. If that column is at or after `pivot_col`, swaps the
// values of the literals of that column with those of column `pivot_col`, in
// every row.
//
// Nothing is done if no column or several columns of the row have the given
// value, if the column found is before `pivot_col`, or if a literal of either
// column has no value.
void SolutionCrush::MaybeSwapOrbitopeColumns(
    absl::Span<const std::vector<int>> orbitope, int row, int pivot_col,
    bool value) {
  if (!solution_is_loaded_) return;
  const std::vector<int>& row_literals = orbitope[row];
  int col = -1;
  for (int c = 0; c < row_literals.size(); ++c) {
    if (GetLiteralValue(row_literals[c]) != value) continue;
    if (col != -1) {
      VLOG(2) << "Multiple literals in row with given value";
      return;
    }
    col = c;
  }
  // Nothing to do (this also covers the case where no column was found).
  if (col < pivot_col) return;

  // The swap is only done if all the literals of both columns have a value.
  for (const std::vector<int>& literals : orbitope) {
    if (!HasValue(PositiveRef(literals[col]))) return;
    if (!HasValue(PositiveRef(literals[pivot_col]))) return;
  }
  for (const std::vector<int>& literals : orbitope) {
    const int col_lit = literals[col];
    const int pivot_lit = literals[pivot_col];
    const bool col_value = GetLiteralValue(col_lit);
    const bool pivot_value = GetLiteralValue(pivot_lit);
    SetLiteralValue(col_lit, pivot_value);
    SetLiteralValue(pivot_lit, col_value);
  }
}
void SolutionCrush::UpdateRefsWithDominance(
int ref, int64_t min_value, int64_t max_value,
absl::Span<const std::pair<int, Domain>> dominating_refs) {
if (!solution_is_loaded_) return;
const std::optional<int64_t> ref_value = GetRefValue(ref);
if (!ref_value.has_value()) return;
// This can happen if the solution is not initially feasible (in which
// case we can't fix it).
if (*ref_value < min_value) return;
// If the value is already in the new domain there is nothing to do.
if (*ref_value <= max_value) return;
// The quantity to subtract from the value of `ref`.
const int64_t ref_value_delta = *ref_value - max_value;
SetRefValue(ref, *ref_value - ref_value_delta);
int64_t remaining_delta = ref_value_delta;
for (const auto& [dominating_ref, dominating_ref_domain] : dominating_refs) {
const std::optional<int64_t> dominating_ref_value =
GetRefValue(dominating_ref);
if (!dominating_ref_value.has_value()) continue;
const int64_t new_dominating_ref_value =
dominating_ref_domain.ValueAtOrBefore(*dominating_ref_value +
remaining_delta);
// This might happen if the solution is not initially feasible.
if (!dominating_ref_domain.Contains(new_dominating_ref_value)) continue;
SetRefValue(dominating_ref, new_dominating_ref_value);
remaining_delta -= (new_dominating_ref_value - *dominating_ref_value);
if (remaining_delta == 0) break;
}
}
// Sets the value of one variable of the linear constraint
// "sum(coeffs[i] * vars[i]) == rhs" so that the constraint holds, given the
// values of the other variables.
//
// The variable to set is `vars[var_index]` if `var_index` is given. Otherwise
// it is the only variable of the constraint without a value, if there is
// exactly one. Nothing is done if:
//   - no solution is loaded,
//   - a variable other than the one to set has no value,
//   - `var_index` is not given and all the variables already have a value,
//   - the coefficient of the variable to set is 0 (the constraint cannot be
//     solved for that variable).
//
// The value written is `remaining / coeff` using integer division. When the
// file is compiled with CHECK_CRUSH, a non-exact division is a fatal error.
void SolutionCrush::SetVarToLinearConstraintSolution(
    std::optional<int> var_index, absl::Span<const int> vars,
    absl::Span<const int64_t> coeffs, int64_t rhs) {
  DCHECK_EQ(vars.size(), coeffs.size());
  DCHECK(!var_index.has_value() || var_index.value() < vars.size());
  if (!solution_is_loaded_) return;
  // `rhs` minus the contribution of all the variables except the one to set.
  int64_t term_value = rhs;
  for (int i = 0; i < vars.size(); ++i) {
    if (HasValue(vars[i])) {
      if (i != var_index) {
        term_value -= GetVarValue(vars[i]) * coeffs[i];
      }
    } else if (!var_index.has_value()) {
      // First variable without a value: this is the one to solve for.
      var_index = i;
    } else if (var_index.value() != i) {
      // A second variable has no value: the solution is not unique.
      return;
    }
  }
  if (!var_index.has_value()) return;
  const int64_t coeff = coeffs[var_index.value()];
  // Avoid a division by zero: a zero coefficient gives no information on the
  // value of the variable.
  if (coeff == 0) return;
  SetVarValue(vars[var_index.value()], term_value / coeff);
#ifdef CHECK_CRUSH
  if (term_value % coeff != 0) {
    std::stringstream lhs;
    for (int i = 0; i < vars.size(); ++i) {
      lhs << (i == var_index ? "x" : std::to_string(GetVarValue(vars[i])));
      lhs << " * " << coeffs[i];
      if (i + 1 < vars.size()) lhs << " + ";
    }
    // Stream the accumulated text: the stream object itself is not printable.
    LOG(FATAL) << "Linear constraint incompatible with solution: " << lhs.str()
               << " != " << rhs;
  }
#endif
}
// Sets the values of the variables created when a reservoir constraint is
// expanded into a circuit: the `level_vars` (one per event, indexed like the
// events of `reservoir`) and the arc literals of `circuit`.
//
// The active events are ordered by time. Events occurring at the same time are
// ordered so that the level after each event stays within
// [`min_level`, `max_level`]. In `circuit`, node `i` is event `i` and the node
// with index `reservoir.time_exprs_size()` is the start/end node.
//
// Returns without setting any arc literal if an "active" literal, a time or a
// level change cannot be evaluated, or if no valid order of the events is found
// (in that last case some `level_vars` may already have been set).
void SolutionCrush::SetReservoirCircuitVars(
    const ReservoirConstraintProto& reservoir, int64_t min_level,
    int64_t max_level, absl::Span<const int> level_vars,
    const CircuitConstraintProto& circuit) {
  if (!solution_is_loaded_) return;
  // The values of the active events, in the order they should appear in the
  // circuit. The values are collected first, and sorted later.
  struct ReservoirEventValues {
    int index;  // In the reservoir constraint.
    int64_t time;
    int64_t level_change;
  };
  const int num_events = reservoir.time_exprs_size();
  std::vector<ReservoirEventValues> active_event_values;
  for (int i = 0; i < num_events; ++i) {
    if (!HasValue(PositiveRef(reservoir.active_literals(i)))) return;
    if (GetLiteralValue(reservoir.active_literals(i))) {
      const std::optional<int64_t> time_value =
          GetExpressionValue(reservoir.time_exprs(i));
      const std::optional<int64_t> change_value =
          GetExpressionValue(reservoir.level_changes(i));
      if (!time_value.has_value() || !change_value.has_value()) return;
      active_event_values.push_back(
          {i, time_value.value(), change_value.value()});
    }
  }

  // Update the `level_vars` values by computing the level at each active event.
  std::sort(active_event_values.begin(), active_event_values.end(),
            [](const ReservoirEventValues& a, const ReservoirEventValues& b) {
              return a.time < b.time;
            });
  int64_t current_level = 0;
  for (int i = 0; i < active_event_values.size(); ++i) {
    int j = i;
    // Adjust the order of the events occurring at the same time, in the
    // circuit, so that, at each node, the level is between `min_level` and
    // `max_level`. For instance, if e1 = {t, +1} and e2 = {t, -1}, and if
    // `current_level` = 0, `min_level` = -1 and `max_level` = 0, then e2 must
    // occur before e1.
    while (j < active_event_values.size() &&
           active_event_values[j].time == active_event_values[i].time &&
           (current_level + active_event_values[j].level_change < min_level ||
            current_level + active_event_values[j].level_change > max_level)) {
      ++j;
    }
    if (j < active_event_values.size() &&
        active_event_values[j].time == active_event_values[i].time) {
      if (i != j) std::swap(active_event_values[i], active_event_values[j]);
      current_level += active_event_values[i].level_change;
      SetVarValue(level_vars[active_event_values[i].index], current_level);
    } else {
      // No event at this time keeps the level within bounds.
      return;
    }
  }

  // The index of each event in `active_event_values`, or -1 if the event's
  // "active" value is false.
  std::vector<int> active_event_value_index(num_events, -1);
  for (int i = 0; i < active_event_values.size(); ++i) {
    active_event_value_index[active_event_values[i].index] = i;
  }
  for (int i = 0; i < circuit.literals_size(); ++i) {
    const int head = circuit.heads(i);
    const int tail = circuit.tails(i);
    const int literal = circuit.literals(i);
    if (tail == num_events) {
      if (head == num_events) {
        // Self-arc on the start and end node.
        SetLiteralValue(literal, active_event_values.empty());
      } else {
        // Arc from the start node to an event node.
        SetLiteralValue(literal, !active_event_values.empty() &&
                                     active_event_values.front().index == head);
      }
    } else if (head == num_events) {
      // Arc from an event node to the end node.
      SetLiteralValue(literal, !active_event_values.empty() &&
                                   active_event_values.back().index == tail);
    } else if (tail != head) {
      // Arc between two different event nodes: true if both events are active
      // and `head` immediately follows `tail` in the order computed above.
      const int tail_index = active_event_value_index[tail];
      const int head_index = active_event_value_index[head];
      SetLiteralValue(literal, tail_index != -1 && head_index != -1 &&
                                   head_index == tail_index + 1);
    }
  }
}
// Sets `var` to the truth value of "`active_i` and `active_j` are both non-zero
// and `time_i` <= `time_j`". Nothing is written if one of the two expressions
// or one of the two "active" references cannot be evaluated.
void SolutionCrush::SetVarToReifiedPrecedenceLiteral(
    int var, const LinearExpressionProto& time_i,
    const LinearExpressionProto& time_j, int active_i, int active_j) {
  if (!solution_is_loaded_) return;
  const std::optional<int64_t> time_i_value = GetExpressionValue(time_i);
  const std::optional<int64_t> time_j_value = GetExpressionValue(time_j);
  const std::optional<int64_t> active_i_value = GetRefValue(active_i);
  const std::optional<int64_t> active_j_value = GetRefValue(active_j);
  const bool all_known = time_i_value.has_value() && time_j_value.has_value() &&
                         active_i_value.has_value() &&
                         active_j_value.has_value();
  if (!all_known) return;
  const bool both_active = (*active_i_value != 0) && (*active_j_value != 0);
  const bool i_before_j = both_active && (*time_i_value <= *time_j_value);
  SetVarValue(var, i_before_j);
}
// Sets the values of the variables created by the expansion of the int_mod
// constraint `ct`, i.e. "target = exprs(0) % exprs(1)":
//   - `div_var` is set to exprs(0) / exprs(1),
//   - `prod_var` is set to exprs(0) - target.
// If `ct` is not enforced (one of its enforcement literals is false), the two
// variables are set to `default_div_value` and `default_prod_value` instead.
//
// Nothing is written if an enforcement literal has no value, if one of the
// expressions cannot be evaluated, or if the modulus evaluates to 0 (which can
// only happen with a solution that is not feasible).
void SolutionCrush::SetIntModExpandedVars(const ConstraintProto& ct,
                                          int div_var, int prod_var,
                                          int64_t default_div_value,
                                          int64_t default_prod_value) {
  if (!solution_is_loaded_) return;
  bool enforced_value = true;
  for (const int lit : ct.enforcement_literal()) {
    if (!HasValue(PositiveRef(lit))) return;
    enforced_value = enforced_value && GetLiteralValue(lit);
  }
  if (!enforced_value) {
    SetVarValue(div_var, default_div_value);
    SetVarValue(prod_var, default_prod_value);
    return;
  }
  const LinearArgumentProto& int_mod = ct.int_mod();
  std::optional<int64_t> v = GetExpressionValue(int_mod.exprs(0));
  if (!v.has_value()) return;
  const int64_t expr_value = v.value();
  v = GetExpressionValue(int_mod.exprs(1));
  if (!v.has_value()) return;
  const int64_t mod_expr_value = v.value();
  v = GetExpressionValue(int_mod.target());
  if (!v.has_value()) return;
  const int64_t target_expr_value = v.value();
  // Guard against a division by zero on an infeasible solution: there is no
  // meaningful quotient to store in that case.
  if (mod_expr_value == 0) return;
  // target_expr_value should be equal to "expr_value % mod_expr_value".
  SetVarValue(div_var, expr_value / mod_expr_value);
  SetVarValue(prod_var, expr_value - target_expr_value);
}
// Sets the values of the intermediate product variables created by the
// expansion of `int_prod`: `prod_vars[k]` receives the product of the values of
// exprs(0), ..., exprs(k + 1). The last expression of `int_prod` is not used.
// Stops at the first expression that cannot be evaluated (the product
// variables before it are already set at that point).
void SolutionCrush::SetIntProdExpandedVars(const LinearArgumentProto& int_prod,
                                           absl::Span<const int> prod_vars) {
  DCHECK_EQ(int_prod.exprs_size(), prod_vars.size() + 2);
  if (!solution_is_loaded_) return;
  const std::optional<int64_t> first_value =
      GetExpressionValue(int_prod.exprs(0));
  if (!first_value.has_value()) return;
  int64_t running_product = *first_value;
  const int num_partial_products = int_prod.exprs_size() - 2;
  for (int k = 0; k < num_partial_products; ++k) {
    const std::optional<int64_t> factor =
        GetExpressionValue(int_prod.exprs(k + 1));
    if (!factor.has_value()) return;
    running_product *= *factor;
    SetVarValue(prod_vars[k], running_product);
  }
}
// Sets the values of the `enforcement_lits` created by the expansion of
// `lin_max` (one literal per expression). The literal of the first expression
// whose value is at least the target value is set to true, and all the other
// literals are set to false.
//
// Nothing is written if the target cannot be evaluated. The literals are
// processed in order, and the processing stops at the first expression that
// cannot be evaluated.
void SolutionCrush::SetLinMaxExpandedVars(
    const LinearArgumentProto& lin_max,
    absl::Span<const int> enforcement_lits) {
  if (!solution_is_loaded_) return;
  DCHECK_EQ(enforcement_lits.size(), lin_max.exprs_size());
  const std::optional<int64_t> target_value =
      GetExpressionValue(lin_max.target());
  if (!target_value.has_value()) return;
  bool max_already_selected = false;
  for (int i = 0; i < enforcement_lits.size(); ++i) {
    const std::optional<int64_t> expr_value =
        GetExpressionValue(lin_max.exprs(i));
    if (!expr_value.has_value()) return;
    const bool selects_this_expr =
        !max_already_selected && *target_value <= *expr_value;
    SetLiteralValue(enforcement_lits[i], selects_this_expr);
    max_already_selected = max_already_selected || selects_this_expr;
  }
}
void SolutionCrush::SetAutomatonExpandedVars(
const AutomatonConstraintProto& automaton,
absl::Span<const StateVar> state_vars,
absl::Span<const TransitionVar> transition_vars) {
if (!solution_is_loaded_) return;
absl::flat_hash_map<std::pair<int64_t, int64_t>, int64_t> transitions;
for (int i = 0; i < automaton.transition_tail_size(); ++i) {
transitions[{automaton.transition_tail(i), automaton.transition_label(i)}] =
automaton.transition_head(i);
}
std::vector<int64_t> label_values;
std::vector<int64_t> state_values;
int64_t current_state = automaton.starting_state();
state_values.push_back(current_state);
for (int i = 0; i < automaton.exprs_size(); ++i) {
const std::optional<int64_t> label_value =
GetExpressionValue(automaton.exprs(i));
if (!label_value.has_value()) return;
label_values.push_back(label_value.value());
const auto it = transitions.find({current_state, label_value.value()});
if (it == transitions.end()) return;
current_state = it->second;
state_values.push_back(current_state);
}
for (const auto& [var, time, state] : state_vars) {
SetVarValue(var, state_values[time] == state);
}
for (const auto& [var, time, transition_tail, transition_label] :
transition_vars) {
SetVarValue(var, state_values[time] == transition_tail &&
label_values[time] == transition_label);
}
}
// Sets the values of the `new_row_lits` created by the expansion of a table
// constraint over `column_vars`, so that at most one row literal (existing or
// new) is true.
//
// If one of the `existing_row_lits` is true, all the new row literals are set
// to false. Otherwise, the first new row matching the values of `column_vars`
// is set to true and the others to false. A row matches if, for each column,
// its list of values is empty or contains the value of the column's variable.
//
// Nothing is written if one of the `existing_row_lits` has no value.
void SolutionCrush::SetTableExpandedVars(
    absl::Span<const int> column_vars, absl::Span<const int> existing_row_lits,
    absl::Span<const TableRowLiteral> new_row_lits) {
  if (!solution_is_loaded_) return;
  bool some_row_is_selected = false;
  for (const int lit : existing_row_lits) {
    if (!HasValue(PositiveRef(lit))) return;
    if (GetLiteralValue(lit)) some_row_is_selected = true;
  }
  const int num_columns = column_vars.size();
  for (const auto& [row_lit, row_values] : new_row_lits) {
    if (some_row_is_selected) {
      SetLiteralValue(row_lit, false);
      continue;
    }
    bool row_matches = true;
    for (int col = 0; row_matches && col < num_columns; ++col) {
      const auto& allowed = row_values[col];
      if (allowed.empty()) continue;
      row_matches = std::find(allowed.begin(), allowed.end(),
                              GetVarValue(column_vars[col])) != allowed.end();
    }
    SetLiteralValue(row_lit, row_matches);
    some_row_is_selected = row_matches;
  }
}
// Sets the values of the `bucket_lits` created by the expansion of `linear`:
// `bucket_lits[i]` is set to whether the value of the linear expression lies in
// the i-th interval [domain(2 * i), domain(2 * i + 1)] of the constraint's
// domain. Nothing is written if a variable of the expression has no value.
void SolutionCrush::SetLinearWithComplexDomainExpandedVars(
    const LinearConstraintProto& linear, absl::Span<const int> bucket_lits) {
  if (!solution_is_loaded_) return;
  int64_t activity = 0;
  const int num_terms = linear.vars_size();
  for (int term = 0; term < num_terms; ++term) {
    const int var = linear.vars(term);
    if (!HasValue(var)) return;
    activity += linear.coeffs(term) * GetVarValue(var);
  }
  DCHECK_LE(bucket_lits.size(), linear.domain_size() / 2);
  int bucket = 0;
  for (const int bucket_lit : bucket_lits) {
    const int64_t lower_bound = linear.domain(2 * bucket);
    const int64_t upper_bound = linear.domain(2 * bucket + 1);
    SetLiteralValue(bucket_lit,
                    lower_bound <= activity && activity <= upper_bound);
    ++bucket;
  }
}
// Replaces the solution hint of `model` by the current solution: one
// (variable, value) entry per variable that has a value, in increasing variable
// order. Does nothing (the existing hint is kept) if no solution is loaded.
void SolutionCrush::StoreSolutionAsHint(CpModelProto& model) const {
  if (!solution_is_loaded_) return;
  model.clear_solution_hint();
  const int num_vars = var_values_.size();
  for (int var = 0; var < num_vars; ++var) {
    if (!var_has_value_[var]) continue;
    // The hint message is only touched when there is something to store.
    auto* hint = model.mutable_solution_hint();
    hint->add_vars(var);
    hint->add_values(var_values_[var]);
  }
}
// Applies `permutation` to the solution: both the values and the "has a value"
// flags of the variables are permuted. CHECK-fails if no solution is loaded.
void SolutionCrush::PermuteVariables(const SparsePermutation& permutation) {
  CHECK(solution_is_loaded_);
  // The two collections are independent, they must simply stay in sync.
  permutation.ApplyToDenseCollection(var_values_);
  permutation.ApplyToDenseCollection(var_has_value_);
}
void SolutionCrush::AssignVariableToPackingArea(
const CompactVectorVector<int, Rectangle>& areas, const CpModelProto& model,
absl::Span<const int> x_intervals, absl::Span<const int> y_intervals,
absl::Span<const BoxInAreaLiteral> box_in_area_lits) {
if (!solution_is_loaded_) return;
struct RectangleTypeAndIndex {
enum class Type {
kHintedBox,
kArea,
};
int index;
Type type;
};
std::vector<Rectangle> rectangles_for_intersections;
std::vector<RectangleTypeAndIndex> rectangles_index;
for (int i = 0; i < x_intervals.size(); ++i) {
const ConstraintProto& x_ct = model.constraints(x_intervals[i]);
const ConstraintProto& y_ct = model.constraints(y_intervals[i]);
const std::optional<int64_t> x_min =
GetExpressionValue(x_ct.interval().start());
const std::optional<int64_t> x_max =
GetExpressionValue(x_ct.interval().end());
const std::optional<int64_t> y_min =
GetExpressionValue(y_ct.interval().start());
const std::optional<int64_t> y_max =
GetExpressionValue(y_ct.interval().end());
if (!x_min.has_value() || !x_max.has_value() || !y_min.has_value() ||
!y_max.has_value()) {
return;
}
if (*x_min > *x_max || *y_min > *y_max) {
VLOG(2) << "Hinted no_overlap_2d coordinate has max lower than min";
return;
}
const Rectangle box = {.x_min = x_min.value(),
.x_max = x_max.value(),
.y_min = y_min.value(),
.y_max = y_max.value()};
rectangles_for_intersections.push_back(box);
rectangles_index.push_back(
{.index = i, .type = RectangleTypeAndIndex::Type::kHintedBox});
}
for (int i = 0; i < areas.size(); ++i) {
for (const Rectangle& area : areas[i]) {
rectangles_for_intersections.push_back(area);
rectangles_index.push_back(
{.index = i, .type = RectangleTypeAndIndex::Type::kArea});
}
}
const std::vector<std::pair<int, int>> intersections =
FindPartialRectangleIntersections(rectangles_for_intersections);
absl::flat_hash_set<std::pair<int, int>> box_to_area_pairs;
for (const auto& [rec1_index, rec2_index] : intersections) {
RectangleTypeAndIndex rec1 = rectangles_index[rec1_index];
RectangleTypeAndIndex rec2 = rectangles_index[rec2_index];
if (rec1.type == rec2.type) {
DCHECK(rec1.type == RectangleTypeAndIndex::Type::kHintedBox);
VLOG(2) << "Hinted position of boxes in no_overlap_2d are overlapping";
return;
}
if (rec1.type != RectangleTypeAndIndex::Type::kHintedBox) {
std::swap(rec1, rec2);
}
box_to_area_pairs.insert({rec1.index, rec2.index});
}
for (const auto& [box_index, area_index, literal] : box_in_area_lits) {
SetLiteralValue(literal,
box_to_area_pairs.contains({box_index, area_index}));
}
}
} // namespace sat
} // namespace operations_research