From c4fac7717467904f16b37ee4bc055fb018eb0fce Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Fri, 27 Sep 2024 14:55:35 +0200 Subject: [PATCH] [CP-SAT] Fix #4373 --- ortools/algorithms/sparse_permutation.cc | 26 ++++ ortools/algorithms/sparse_permutation.h | 26 ++++ ortools/algorithms/sparse_permutation_test.cc | 15 ++ ortools/graph/BUILD.bazel | 18 --- ortools/sat/BUILD.bazel | 4 +- ortools/sat/cp_model_symmetries.cc | 63 +++++++- ortools/sat/presolve_context.cc | 8 + ortools/sat/presolve_context.h | 3 + ortools/sat/symmetry_util.cc | 38 +++++ ortools/sat/symmetry_util.h | 13 ++ ortools/sat/symmetry_util_test.cc | 138 +++++++++++------- 11 files changed, 275 insertions(+), 77 deletions(-) diff --git a/ortools/algorithms/sparse_permutation.cc b/ortools/algorithms/sparse_permutation.cc index ab6801fef1..9d4b6a466d 100644 --- a/ortools/algorithms/sparse_permutation.cc +++ b/ortools/algorithms/sparse_permutation.cc @@ -81,4 +81,30 @@ std::string SparsePermutation::DebugString() const { return out; } +int SparsePermutation::Image(int element) const { + for (int c = 0; c < NumCycles(); ++c) { + int cur_element = LastElementInCycle(c); + for (int image : Cycle(c)) { + if (cur_element == element) { + return image; + } + cur_element = image; + } + } + return element; +} + +int SparsePermutation::InverseImage(int element) const { + for (int c = 0; c < NumCycles(); ++c) { + int cur_element = LastElementInCycle(c); + for (int image : Cycle(c)) { + if (image == element) { + return cur_element; + } + cur_element = image; + } + } + return element; +} + } // namespace operations_research diff --git a/ortools/algorithms/sparse_permutation.h b/ortools/algorithms/sparse_permutation.h index 0cbee4f9ec..ee9b70db5c 100644 --- a/ortools/algorithms/sparse_permutation.h +++ b/ortools/algorithms/sparse_permutation.h @@ -59,6 +59,11 @@ class SparsePermutation { // information with the loop above. Not sure it is needed though. int LastElementInCycle(int i) const; + // Returns the image of the given element or `element` itself if it is stable + // under the permutation. + int Image(int element) const; + int InverseImage(int element) const; + // To add a cycle to the permutation, repeatedly call AddToCurrentCycle() // with the cycle's orbit, then call CloseCurrentCycle(); // This shouldn't be called on trivial cycles (of length 1). @@ -76,6 +81,9 @@ class SparsePermutation { // Example: "(1 4 3) (5 9) (6 8 7)". std::string DebugString() const; + template + void ApplyToDenseCollection(Collection& span) const; + private: const int size_; std::vector cycles_; @@ -129,6 +137,24 @@ inline int SparsePermutation::LastElementInCycle(int i) const { return cycles_[cycle_ends_[i] - 1]; } +template +void SparsePermutation::ApplyToDenseCollection(Collection& span) const { + using T = typename Collection::value_type; + for (int c = 0; c < NumCycles(); ++c) { + const int last_element_idx = LastElementInCycle(c); + int element = last_element_idx; + T last_element = span[element]; + for (int image : Cycle(c)) { + if (image == last_element_idx) { + span[element] = last_element; + } else { + span[element] = span[image]; + } + element = image; + } + } +} + } // namespace operations_research #endif // OR_TOOLS_ALGORITHMS_SPARSE_PERMUTATION_H_ diff --git a/ortools/algorithms/sparse_permutation_test.cc b/ortools/algorithms/sparse_permutation_test.cc index a31927b2b8..44aead1cf8 100644 --- a/ortools/algorithms/sparse_permutation_test.cc +++ b/ortools/algorithms/sparse_permutation_test.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include "absl/container/flat_hash_set.h" @@ -73,6 +74,20 @@ TEST(SparsePermutationTest, Identity) { EXPECT_EQ(0, permutation.NumCycles()); } +TEST(SparsePermutationTest, ApplyToVector) { + std::vector v = {"0", "1", "2", "3", "4", "5", "6", "7", "8"}; + SparsePermutation permutation(v.size()); + permutation.AddToCurrentCycle(4); + permutation.AddToCurrentCycle(2); + permutation.AddToCurrentCycle(7); + permutation.CloseCurrentCycle(); + permutation.AddToCurrentCycle(6); + permutation.AddToCurrentCycle(1); + permutation.CloseCurrentCycle(); + permutation.ApplyToDenseCollection(v); + EXPECT_THAT(v, ElementsAre("0", "6", "7", "3", "2", "5", "1", "4", "8")); +} + // Generate a bunch of permutation on a 'huge' space, but that have very few // displacements. This would OOM if the implementation was O(N); we verify // that it doesn't. diff --git a/ortools/graph/BUILD.bazel b/ortools/graph/BUILD.bazel index 98e6c8a510..824d6cb89b 100644 --- a/ortools/graph/BUILD.bazel +++ b/ortools/graph/BUILD.bazel @@ -355,24 +355,6 @@ cc_library( ], ) -# need C++20 -#cc_test( -# name = "k_shortest_paths_test", -# srcs = ["k_shortest_paths_test.cc"], -# deps = [ -# ":graph", -# ":io", -# ":k_shortest_paths", -# ":shortest_paths", -# "//ortools/base:gmock_main", -# "@com_google_absl//absl/algorithm:container", -# "@com_google_absl//absl/log:check", -# "@com_google_absl//absl/random:distributions", -# "@com_google_absl//absl/strings", -# "@com_google_benchmark//:benchmark", -# ], -#) - # Flow problem protobuf representation proto_library( name = "flow_problem_proto", diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 5b2214ff06..c21059e05e 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -657,7 +657,6 @@ cc_library( hdrs = ["presolve_context.h"], deps = [ ":cp_model_cc_proto", - ":cp_model_checker", ":cp_model_loader", ":cp_model_mapping", ":cp_model_utils", @@ -668,6 +667,7 @@ cc_library( ":sat_parameters_cc_proto", ":sat_solver", ":util", + "//ortools/algorithms:sparse_permutation", "//ortools/base", "//ortools/base:mathutil", "//ortools/port:proto_utils", @@ -1163,6 +1163,7 @@ cc_library( "//ortools/algorithms:dynamic_partition", "//ortools/algorithms:sparse_permutation", "//ortools/base", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", ], @@ -1176,6 +1177,7 @@ cc_test( ":symmetry_util", "//ortools/algorithms:sparse_permutation", "//ortools/base:gmock_main", + "@com_google_absl//absl/types:span", ], ) diff --git a/ortools/sat/cp_model_symmetries.cc b/ortools/sat/cp_model_symmetries.cc index 1c53e17e9e..b489d655b6 100644 --- a/ortools/sat/cp_model_symmetries.cc +++ b/ortools/sat/cp_model_symmetries.cc @@ -895,6 +895,47 @@ std::vector BuildInequalityCoeffsForOrbitope( return out; } +void UpdateHintAfterFixingBoolToBreakSymmetry( + PresolveContext* context, int var, bool fixed_value, + const std::vector>& generators) { + if (!context->VarHasSolutionHint(var)) { + return; + } + const int64_t hinted_value = context->SolutionHint(var); + if (hinted_value == static_cast(fixed_value)) { + return; + } + + std::vector schrier_vector; + std::vector orbit; + GetSchreierVectorAndOrbit(var, generators, &schrier_vector, &orbit); + + bool found_target = false; + int target_var; + for (int v : orbit) { + if (context->VarHasSolutionHint(v) && + context->SolutionHint(v) == static_cast(fixed_value)) { + found_target = true; + target_var = v; + break; + } + } + if (!found_target) { + context->UpdateRuleStats( + "hint: couldn't transform infeasible hint properly"); + return; + } + + const std::vector generator_idx = + TracePoint(target_var, schrier_vector, generators); + for (const int i : generator_idx) { + context->PermuteHintValues(*generators[i]); + } + + DCHECK(context->VarHasSolutionHint(var)); + DCHECK_EQ(context->SolutionHint(var), fixed_value); +} + } // namespace bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { @@ -1010,6 +1051,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { // fixing do not exploit the full structure of these symmeteries. Note // however that the fixing via propagation above close cod105 even more // efficiently. + std::vector var_can_be_true_per_orbit(num_vars, -1); { std::vector tmp_to_clear; std::vector tmp_sizes(num_vars, 0); @@ -1050,7 +1092,11 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { } // We push all but the first one in each orbit. - if (tmp_sizes[rep] == 0) can_be_fixed_to_false.push_back(var); + if (tmp_sizes[rep] == 0) { + can_be_fixed_to_false.push_back(var); + } else { + var_can_be_true_per_orbit[rep] = var; + } tmp_sizes[rep] = 0; } } else { @@ -1131,7 +1177,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { } } - // Supper simple heuristic to use the orbitope or not. + // Super simple heuristic to use the orbitope or not. // // In an orbitope with an at most one on each row, we can fix the upper right // triangle. We could use a formula, but the loop is fast enough. @@ -1153,6 +1199,19 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { const int var = can_be_fixed_to_false[i]; if (orbits[var] == orbit_index) ++num_in_orbit; context->UpdateRuleStats("symmetry: fixed to false in general orbit"); + if (context->VarHasSolutionHint(var) && context->SolutionHint(var) == 1 && + var_can_be_true_per_orbit[orbits[var]] != -1) { + // We are breaking the symmetry in a way that makes the hint invalid. + // We want `var` to be false, so we would naively pick a symmetry to + // enforce that. But that will be wrong if we do this twice: after we + // permute the hint to fix the first one we would look for a symmetry + // group element that fixes the second one to false. But there are many + // of those, and picking the wrong one would risk making the first one + // true again. Since this is a AMO, fixing the one that is true doesn't + // have this problem. + UpdateHintAfterFixingBoolToBreakSymmetry( + context, var_can_be_true_per_orbit[orbits[var]], true, generators); + } if (!context->SetLiteralToFalse(var)) return false; } diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index a5dab8dd6b..c4789f9bff 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -33,6 +33,7 @@ #include "absl/numeric/int128.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" #include "ortools/port/proto_utils.h" @@ -725,6 +726,7 @@ void PresolveContext::UpdateConstraintVariableUsage(int c) { } bool PresolveContext::ConstraintVariableGraphIsUpToDate() const { + if (is_unsat_) return true; // We do not care in this case. return constraint_to_vars_.size() == working_model->constraints_size(); } @@ -1016,6 +1018,12 @@ bool PresolveContext::CanonicalizeAffineVariable(int ref, int64_t coeff, return true; } +void PresolveContext::PermuteHintValues(const SparsePermutation& perm) { + CHECK(hint_is_loaded_); + perm.ApplyToDenseCollection(hint_); + perm.ApplyToDenseCollection(hint_has_value_); +} + bool PresolveContext::StoreAffineRelation(int ref_x, int ref_y, int64_t coeff, int64_t offset, bool debug_no_recursion) { diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index faa7a39800..1dfac184ef 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -28,6 +28,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/logging.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" @@ -574,6 +575,8 @@ class PresolveContext { // the hint, in order to maintain it as best as possible during presolve. void LoadSolutionHint(); + void PermuteHintValues(const SparsePermutation& perm); + // Solution hint accessor. bool VarHasSolutionHint(int var) const { return hint_has_value_[var]; } int64_t SolutionHint(int var) const { return hint_[var]; } diff --git a/ortools/sat/symmetry_util.cc b/ortools/sat/symmetry_util.cc index 78edf8fe87..c1d96e0a38 100644 --- a/ortools/sat/symmetry_util.cc +++ b/ortools/sat/symmetry_util.cc @@ -18,6 +18,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/algorithms/dynamic_partition.h" @@ -194,5 +195,42 @@ std::vector GetOrbitopeOrbits( return orbits; } +void GetSchreierVectorAndOrbit( + int point, absl::Span> generators, + std::vector* schrier_vector, std::vector* orbit) { + schrier_vector->clear(); + *orbit = {point}; + if (generators.empty()) return; + schrier_vector->resize(generators[0]->Size(), -1); + absl::flat_hash_set orbit_set = {point}; + for (int i = 0; i < orbit->size(); ++i) { + const int orbit_element = (*orbit)[i]; + for (int i = 0; i < generators.size(); ++i) { + DCHECK_EQ(schrier_vector->size(), generators[i]->Size()); + const int image = generators[i]->Image(orbit_element); + if (image == orbit_element) continue; + const auto [it, inserted] = orbit_set.insert(image); + if (inserted) { + (*schrier_vector)[image] = i; + orbit->push_back(image); + } + } + } +} + +std::vector TracePoint( + int point, absl::Span schrier_vector, + absl::Span> generators) { + std::vector result; + while (schrier_vector[point] != -1) { + const SparsePermutation& perm = *generators[schrier_vector[point]]; + result.push_back(schrier_vector[point]); + const int next = perm.InverseImage(point); + DCHECK_NE(next, point); + point = next; + } + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/symmetry_util.h b/ortools/sat/symmetry_util.h index f045be430e..5e5e813d6e 100644 --- a/ortools/sat/symmetry_util.h +++ b/ortools/sat/symmetry_util.h @@ -62,6 +62,19 @@ std::vector GetOrbits( std::vector GetOrbitopeOrbits(int n, absl::Span> orbitope); +// See Chapter 7 of Butler, Gregory, ed. Fundamental algorithms for permutation +// groups. Berlin, Heidelberg: Springer Berlin Heidelberg, 1991. +void GetSchreierVectorAndOrbit( + int point, absl::Span> generators, + std::vector* schrier_vector, std::vector* orbit); + +// Given a schreier vector for a given base point and a point in the same orbit +// of the base point, returns a list of index of the `generators` to apply to +// get a permutation mapping the base point to get the given point. +std::vector TracePoint( + int point, absl::Span schrier_vector, + absl::Span> generators); + // Given the generators for a permutation group of [0, n-1], update it to // a set of generators of the group stabilizing the given element. // diff --git a/ortools/sat/symmetry_util_test.cc b/ortools/sat/symmetry_util_test.cc index 85a7d67481..9b3a5b19f5 100644 --- a/ortools/sat/symmetry_util_test.cc +++ b/ortools/sat/symmetry_util_test.cc @@ -13,9 +13,12 @@ #include "ortools/sat/symmetry_util.h" +#include #include +#include #include +#include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/gmock.h" @@ -25,24 +28,25 @@ namespace sat { namespace { using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +std::unique_ptr MakePerm( + int size, absl::Span> cycles) { + auto perm = std::make_unique(size); + for (const auto& cycle : cycles) { + for (const int x : cycle) { + perm->AddToCurrentCycle(x); + } + perm->CloseCurrentCycle(); + } + return perm; +} TEST(GetOrbitsTest, BasicExample) { const int n = 10; std::vector> generators; - generators.push_back(std::make_unique(n)); - generators[0]->AddToCurrentCycle(0); - generators[0]->AddToCurrentCycle(1); - generators[0]->AddToCurrentCycle(2); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(7); - generators[0]->AddToCurrentCycle(8); - generators[0]->CloseCurrentCycle(); - - generators.push_back(std::make_unique(n)); - generators[1]->AddToCurrentCycle(3); - generators[1]->AddToCurrentCycle(2); - generators[1]->AddToCurrentCycle(7); - generators[1]->CloseCurrentCycle(); + generators.push_back(MakePerm(n, {{0, 1, 2}, {7, 8}})); + generators.push_back(MakePerm(n, {{3, 2, 7}})); const std::vector orbits = GetOrbits(n, generators); for (const int i : std::vector{0, 1, 2, 3, 7, 8}) { EXPECT_EQ(orbits[i], 0); @@ -60,27 +64,8 @@ TEST(BasicOrbitopeExtractionTest, BasicExample) { const int n = 10; std::vector> generators; - generators.push_back(std::make_unique(n)); - generators[0]->AddToCurrentCycle(0); - generators[0]->AddToCurrentCycle(1); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(4); - generators[0]->AddToCurrentCycle(5); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(8); - generators[0]->AddToCurrentCycle(7); - generators[0]->CloseCurrentCycle(); - - generators.push_back(std::make_unique(n)); - generators[1]->AddToCurrentCycle(2); - generators[1]->AddToCurrentCycle(1); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(5); - generators[1]->AddToCurrentCycle(3); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(6); - generators[1]->AddToCurrentCycle(7); - generators[1]->CloseCurrentCycle(); + generators.push_back(MakePerm(n, {{0, 1}, {4, 5}, {8, 7}})); + generators.push_back(MakePerm(n, {{2, 1}, {5, 3}, {6, 7}})); const std::vector> orbitope = BasicOrbitopeExtraction(generators); @@ -99,27 +84,8 @@ TEST(BasicOrbitopeExtractionTest, NotAnOrbitopeBecauseOfDuplicates) { const int n = 10; std::vector> generators; - generators.push_back(std::make_unique(n)); - generators[0]->AddToCurrentCycle(0); - generators[0]->AddToCurrentCycle(1); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(4); - generators[0]->AddToCurrentCycle(5); - generators[0]->CloseCurrentCycle(); - generators[0]->AddToCurrentCycle(8); - generators[0]->AddToCurrentCycle(7); - generators[0]->CloseCurrentCycle(); - - generators.push_back(std::make_unique(n)); - generators[1]->AddToCurrentCycle(1); - generators[1]->AddToCurrentCycle(2); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(5); - generators[1]->AddToCurrentCycle(8); - generators[1]->CloseCurrentCycle(); - generators[1]->AddToCurrentCycle(6); - generators[1]->AddToCurrentCycle(9); - generators[1]->CloseCurrentCycle(); + generators.push_back(MakePerm(n, {{0, 1}, {4, 5}, {8, 7}})); + generators.push_back(MakePerm(n, {{1, 2}, {5, 8}, {6, 9}})); const std::vector> orbitope = BasicOrbitopeExtraction(generators); @@ -129,6 +95,66 @@ TEST(BasicOrbitopeExtractionTest, NotAnOrbitopeBecauseOfDuplicates) { EXPECT_THAT(orbitope[2], ElementsAre(8, 7)); } +TEST(GetSchreierVectorTest, Square) { + const int n = 4; + std::vector> generators; + generators.push_back(MakePerm(n, {{0, 1, 2, 3}})); + generators.push_back(MakePerm(n, {{1, 3}})); + + std::vector schrier_vector, orbit; + GetSchreierVectorAndOrbit(0, generators, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, 0, 0, 1)); +} + +TEST(GetSchreierVectorTest, ComplicatedGroup) { + // See Chapter 7 of Butler, Gregory, ed. Fundamental algorithms for + // permutation groups. Berlin, Heidelberg: Springer Berlin Heidelberg, 1991. + const int n = 11; + std::vector> generators; + generators.push_back(MakePerm(n, {{0, 3, 4, 10, 5, 9, 2, 1}, {6, 7}})); + generators.push_back(MakePerm(n, {{0, 3, 4, 10, 5, 9, 2, 1}, {7, 8}})); + generators.push_back(MakePerm(n, {{0, 3, 1, 2}, {4, 10, 9, 5}})); + + std::vector schrier_vector, orbit; + GetSchreierVectorAndOrbit(0, generators, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, 2, 2, 0, 0, 0, -1, -1, -1, 2, 0)); + std::vector generators_idx = TracePoint(9, schrier_vector, generators); + std::vector points = {"0", "1", "2", "3", "4", "5", + "6", "7", "8", "9", "10"}; + for (const int i : generators_idx) { + generators[i]->ApplyToDenseCollection(points); + } + // It needs to take the base point 0 to the traced point 9. + EXPECT_THAT(points, ElementsAre("9", "10", "1", "4", "5", "2", "7", "6", "8", + "3", "0")); + GetSchreierVectorAndOrbit(6, generators, &schrier_vector, &orbit); + EXPECT_THAT(orbit, UnorderedElementsAre(6, 7, 8)); + EXPECT_THAT(schrier_vector, + ElementsAre(-1, -1, -1, -1, -1, -1, -1, 0, 1, -1, -1)); +} + +TEST(GetSchreierVectorTest, ProjectivePlaneOrderTwo) { + const int n = 7; + std::vector> generators; + generators.push_back(MakePerm(n, {{0, 1, 3, 4, 6, 2, 5}})); + generators.push_back(MakePerm(n, {{1, 3}, {2, 4}})); + + std::vector schrier_vector, orbit; + GetSchreierVectorAndOrbit(0, generators, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, 0, 1, 0, 0, 0, 0)); + EXPECT_THAT(orbit, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6)); + + // Now let's see the stabilizer of the point 0. + std::vector> stabilizer; + + stabilizer.push_back(MakePerm(n, {{1, 3}, {2, 4}})); + stabilizer.push_back(MakePerm(n, {{3, 4}, {5, 6}})); + stabilizer.push_back(MakePerm(n, {{3, 5}, {4, 6}})); + + GetSchreierVectorAndOrbit(1, stabilizer, &schrier_vector, &orbit); + EXPECT_THAT(schrier_vector, ElementsAre(-1, -1, 0, 0, 1, 2, 2)); +} + } // namespace } // namespace sat } // namespace operations_research