work on diffn; fix non deterministic issue

This commit is contained in:
Laurent Perron
2023-12-28 21:42:18 +01:00
parent 1f65ccef44
commit 0be41f4129
9 changed files with 58 additions and 44 deletions

View File

@@ -136,6 +136,7 @@ cc_library(
"//ortools/util:sorted_interval_list",
"//ortools/util:strong_integers",
"//ortools/util:time_limit",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
@@ -2150,6 +2151,7 @@ cc_library(
"//ortools/util:bitset",
"//ortools/util:sorted_interval_list",
"//ortools/util:strong_integers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log:check",

View File

@@ -311,27 +311,26 @@ bool NonOverlappingRectanglesEnergyPropagator::Propagate() {
}
// We found a conflict, so we can afford to run the propagator again to
// search for a best explanation. This is specially the case since we only
// want to re-run it over the items that participate in the conflict, so its
// want to re-run it over the items that participate in the conflict, so it is
// a much smaller problem.
IntegerValue best_explanation_size = best_conflict->items.size();
bool found_improvement;
bool refined = false;
do {
found_improvement = false;
while (true) {
std::optional<Conflict> conflict = FindConflict(best_conflict->items);
if (!conflict.has_value()) break;
// We prefer an explanation with the least number of boxes.
if (conflict->items.size() < best_explanation_size) {
found_improvement = true;
best_explanation_size = conflict->items.size();
best_conflict = conflict;
refined = true;
if (conflict->items.size() >= best_explanation_size) {
// The new explanation isn't better than the old one. Stop trying.
break;
}
} while (found_improvement);
best_explanation_size = conflict->items.size();
best_conflict = conflict;
refined = true;
}
num_refined_conflicts_ += refined;
std::vector<RectangleInRange> generalized_explanation = GeneralizeExplanation(
best_conflict->rectangle_too_much_energy, best_conflict->items);
best_conflict->rectangle_with_too_much_energy, best_conflict->items);
if (best_explanation_size == 2) {
num_conflicts_two_boxes_++;
}
@@ -342,17 +341,17 @@ bool NonOverlappingRectanglesEnergyPropagator::Propagate() {
std::optional<NonOverlappingRectanglesEnergyPropagator::Conflict>
NonOverlappingRectanglesEnergyPropagator::FindConflict(
std::vector<RectangleInRange> active_box_ranges) {
const std::vector<Rectangle> rectangles_too_much_energy =
const std::vector<Rectangle> rectangles_with_too_much_energy =
FindRectanglesWithEnergyConflictMC(active_box_ranges, *random_, 1.0);
if (rectangles_too_much_energy.empty()) return std::nullopt;
if (rectangles_with_too_much_energy.empty()) return std::nullopt;
num_conflicts_++;
num_multiple_conflicts_ += rectangles_too_much_energy.size() > 1;
num_multiple_conflicts_ += rectangles_with_too_much_energy.size() > 1;
std::vector<RectangleInRange> best_explanation;
Rectangle best_rectangle;
for (const auto& r : rectangles_too_much_energy) {
for (const auto& r : rectangles_with_too_much_energy) {
std::vector<RectangleInRange> range_for_explanation =
GetEnergyConflictForRectangle(r, active_box_ranges);
CheckPropagationIsValid(range_for_explanation, r);
@@ -468,7 +467,7 @@ NonOverlappingRectanglesEnergyPropagator::GetEnergyConflictForRectangle(
[](const OverlapPerBox& a, const OverlapPerBox& b) {
return a.energy > b.energy;
});
IntegerValue available_energy = rectangle.Area();
const IntegerValue available_energy = rectangle.Area();
IntegerValue used_energy = 0;
std::vector<RectangleInRange> ranges_for_explanation;
ranges_for_explanation.reserve(energy_per_box.size());
@@ -496,8 +495,8 @@ int NonOverlappingRectanglesEnergyPropagator::RegisterWith(
void NonOverlappingRectanglesEnergyPropagator::CheckPropagationIsValid(
const std::vector<RectangleInRange>& ranges,
const Rectangle& rectangle_too_much_energy) {
const IntegerValue available_energy = rectangle_too_much_energy.Area();
const Rectangle& rectangle_with_too_much_energy) {
const IntegerValue available_energy = rectangle_with_too_much_energy.Area();
IntegerValue used_energy = 0;
for (const auto& range : ranges) {
const int b = range.box_index;
@@ -510,7 +509,7 @@ void NonOverlappingRectanglesEnergyPropagator::CheckPropagationIsValid(
// Each one of the boxes-in-range that we found on the cut does intersect
// the rectangle we found.
const auto intersection =
range.GetMinimumIntersection(rectangle_too_much_energy);
range.GetMinimumIntersection(rectangle_with_too_much_energy);
CHECK_GT(intersection.Area(), 0);
// It cannot intersect more than the size of the object.
CHECK_GE(x_.SizeMin(b), intersection.SizeX());

View File

@@ -52,7 +52,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface {
private:
struct Conflict {
std::vector<RectangleInRange> items;
Rectangle rectangle_too_much_energy;
Rectangle rectangle_with_too_much_energy;
};
std::optional<Conflict> FindConflict(
std::vector<RectangleInRange> active_box_ranges);
@@ -64,7 +64,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface {
bool BuildAndReportEnergyTooLarge(
const std::vector<RectangleInRange>& ranges);
void CheckPropagationIsValid(const std::vector<RectangleInRange>& ranges,
const Rectangle& rectangle_too_much_energy);
const Rectangle& rectangle_with_too_much_energy);
std::vector<RectangleInRange> GetEnergyConflictForRectangle(
const Rectangle& rectangle,
const std::vector<RectangleInRange>& active_box_ranges);

View File

@@ -23,6 +23,7 @@
#include <utility>
#include <vector>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
@@ -232,7 +233,7 @@ void ElementEncodings::Add(IntegerVariable var,
var_to_index_to_element_encodings_[var][exactly_one_index] = encoding;
}
const absl::flat_hash_map<int, std::vector<ValueLiteralPair>>&
const absl::btree_map<int, std::vector<ValueLiteralPair>>&
ElementEncodings::Get(IntegerVariable var) {
const auto& it = var_to_index_to_element_encodings_.find(var);
if (it == var_to_index_to_element_encodings_.end()) {
@@ -359,12 +360,12 @@ std::vector<LiteralValueValue> ProductDecomposer::TryToDecompose(
}
// Fill in the encodings for the left variable.
const absl::flat_hash_map<int, std::vector<ValueLiteralPair>>&
left_encodings = element_encodings_->Get(left.var);
const absl::btree_map<int, std::vector<ValueLiteralPair>>& left_encodings =
element_encodings_->Get(left.var);
// Fill in the encodings for the right variable.
const absl::flat_hash_map<int, std::vector<ValueLiteralPair>>&
right_encodings = element_encodings_->Get(right.var);
const absl::btree_map<int, std::vector<ValueLiteralPair>>& right_encodings =
element_encodings_->Get(right.var);
std::vector<int> compatible_keys;
for (const auto& [index, encoding] : left_encodings) {

View File

@@ -190,17 +190,17 @@ class ElementEncodings {
int exactly_one_index);
// Returns an empty map if there is no such encoding.
const absl::flat_hash_map<int, std::vector<ValueLiteralPair>>& Get(
const absl::btree_map<int, std::vector<ValueLiteralPair>>& Get(
IntegerVariable var);
// Get an unsorted set of variables appearing in element encodings.
const std::vector<IntegerVariable>& GetElementEncodedVariables() const;
private:
absl::flat_hash_map<IntegerVariable,
absl::flat_hash_map<int, std::vector<ValueLiteralPair>>>
absl::btree_map<IntegerVariable,
absl::btree_map<int, std::vector<ValueLiteralPair>>>
var_to_index_to_element_encodings_;
const absl::flat_hash_map<int, std::vector<ValueLiteralPair>>
const absl::btree_map<int, std::vector<ValueLiteralPair>>
empty_element_encoding_;
std::vector<IntegerVariable> element_encoded_variables_;
};

View File

@@ -862,7 +862,7 @@ void SatSolver::ProcessCurrentConflict() {
}
// Minimize the learned conflict.
MinimizeConflict(&learned_conflict_, &reason_used_to_infer_the_conflict_);
MinimizeConflict(&learned_conflict_);
// Minimize it further with binary clauses?
if (!binary_implication_graph_->IsEmpty()) {
@@ -2042,20 +2042,18 @@ void SatSolver::ComputeFirstUIPConflict(
for (const Literal literal : clause_to_expand) {
const BooleanVariable var = literal.Variable();
const int level = DecisionLevel(var);
if (level > 0) ++num_vars_at_positive_level_in_clause_to_expand;
if (level == 0) continue;
++num_vars_at_positive_level_in_clause_to_expand;
if (!is_marked_[var]) {
is_marked_.Set(var);
++num_new_vars_at_positive_level;
if (level == highest_level) {
++num_new_vars_at_positive_level;
++num_literal_at_highest_level_that_needs_to_be_processed;
} else if (level > 0) {
++num_new_vars_at_positive_level;
} else {
// Note that all these literals are currently false since the clause
// to expand was used to infer the value of a literal at this level.
DCHECK(trail_->Assignment().LiteralIsFalse(literal));
conflict->push_back(literal);
} else {
reason_used_to_infer_the_conflict->push_back(literal);
}
}
}
@@ -2292,9 +2290,7 @@ void SatSolver::ComputePBConflict(int max_trail_index,
LOG(FATAL) << "The code should never reach here.";
}
void SatSolver::MinimizeConflict(
std::vector<Literal>* conflict,
std::vector<Literal>* reason_used_to_infer_the_conflict) {
void SatSolver::MinimizeConflict(std::vector<Literal>* conflict) {
SCOPED_TIME_STAT(&stats_);
const int old_size = conflict->size();

View File

@@ -674,9 +674,7 @@ class SatSolver {
// Precondition: is_marked_ should be set to true for all the variables of
// the conflict. It can also contains false non-conflict variables that
// are implied by the negation of the 1-UIP conflict literal.
void MinimizeConflict(
std::vector<Literal>* conflict,
std::vector<Literal>* reason_used_to_infer_the_conflict);
void MinimizeConflict(std::vector<Literal>* conflict);
void MinimizeConflictExperimental(std::vector<Literal>* conflict);
void MinimizeConflictSimple(std::vector<Literal>* conflict);
void MinimizeConflictRecursively(std::vector<Literal>* conflict);

View File

@@ -31,6 +31,7 @@
#include "ortools/base/helpers.h"
#include "ortools/base/options.h"
#endif // __PORTABLE_PLATFORM__
#include "absl/algorithm/container.h"
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
@@ -964,10 +965,17 @@ void SharedBoundsManager::GetChangedBounds(
absl::MutexLock mutex_lock(&mutex_);
for (const int var : id_to_changed_variables_[id].PositionsSetAtLeastOnce()) {
variables->push_back(var);
}
id_to_changed_variables_[id].ClearAll();
// We need to report the bounds in a deterministic order as it is difficult to
// guarantee that nothing depend on the order in which the new bounds are
// processed.
absl::c_sort(*variables);
for (const int var : *variables) {
new_lower_bounds->push_back(synchronized_lower_bounds_[var]);
new_upper_bounds->push_back(synchronized_upper_bounds_[var]);
}
id_to_changed_variables_[id].ClearAll();
}
void SharedBoundsManager::UpdateDomains(std::vector<Domain>* domains) {

View File

@@ -51,6 +51,11 @@ struct ClosedInterval {
return start < other.start;
}
template <typename H>
friend H AbslHashValue(H h, const ClosedInterval& interval) {
return H::combine(std::move(h), interval.start, interval.end);
}
int64_t start = 0; // Inclusive.
int64_t end = 0; // Inclusive.
};
@@ -451,6 +456,11 @@ class Domain {
return intervals_ != other.intervals_;
}
template <typename H>
friend H AbslHashValue(H h, const Domain& domain) {
return H::combine(std::move(h), domain.intervals_);
}
/**
* Basic read-only std::vector<> wrapping to view a Domain as a sorted list of
* non-adjacent intervals. Note that we don't expose size() which might be