diff --git a/examples/cpp/magic_square_sat.cc b/examples/cpp/magic_square_sat.cc index 99348454e9..34771d3414 100644 --- a/examples/cpp/magic_square_sat.cc +++ b/examples/cpp/magic_square_sat.cc @@ -33,46 +33,52 @@ namespace sat { void MagicSquare(int size) { CpModelBuilder builder; - std::vector > square(size); - std::vector > transposed(size); - std::vector diag1; - std::vector diag2; + std::vector> square(size); std::vector all_variables; - Domain domain(1, size * size); + const Domain domain(1, size * size); for (int i = 0; i < size; ++i) { for (int j = 0; j < size; ++j) { const IntVar var = builder.NewIntVar(domain); square[i].push_back(var); - transposed[j].push_back(var); all_variables.push_back(var); - if (i == j) { - diag1.push_back(var); - } - if (i + j == size) { - diag2.push_back(var); - } } } - // All Diff. + // All cells take different values. for (int i = 0; i < size; ++i) { builder.AddAllDifferent(all_variables); } - const int sum = size * (size * size + 1) / 2; + // The sum on each row, columns and two main diagonals. + const int magic_value = size * (size * size + 1) / 2; + // Sum on rows. for (int i = 0; i < size; ++i) { - builder.AddEquality(LinearExpr::Sum(square[i]), sum); + LinearExpr sum; + for (int j = 0; j < size; ++j) { + sum += square[i][j]; + } + builder.AddEquality(sum, magic_value); } // Sum on columns. - for (int i = 0; i < size; ++i) { - builder.AddEquality(LinearExpr::Sum(transposed[i]), sum); + for (int j = 0; j < size; ++j) { + LinearExpr sum; + for (int i = 0; i < size; ++i) { + sum += square[i][j]; + } + builder.AddEquality(sum, magic_value); } // Sum on diagonals. - builder.AddEquality(LinearExpr::Sum(diag1), sum); - builder.AddEquality(LinearExpr::Sum(diag2), sum); + LinearExpr diag1_sum; + LinearExpr diag2_sum; + for (int i = 0; i < size; ++i) { + diag1_sum += square[i][i]; + diag2_sum += square[i][size - 1 - i]; + } + builder.AddEquality(diag1_sum, magic_value); + builder.AddEquality(diag2_sum, magic_value); Model model; model.Add(NewSatParameters(absl::GetFlag(FLAGS_params)));