diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index d069f70bf3..c19fda325a 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -3903,7 +3903,7 @@ cc_binary( "//ortools/base:file", "//ortools/base:path", "//ortools/util:file_util", - "//ortools/util:filelineiter", + "//ortools/util:logging", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/flags:flag", "@abseil-cpp//absl/log", diff --git a/ortools/sat/cp_model_solver_test.cc b/ortools/sat/cp_model_solver_test.cc index cda37825af..3b7a23e5c8 100644 --- a/ortools/sat/cp_model_solver_test.cc +++ b/ortools/sat/cp_model_solver_test.cc @@ -133,6 +133,7 @@ TEST(RelativeGapLimitTest, BooleanLinearOptimizationProblem) { Model model; SatParameters params; + params.set_num_workers(1); params.set_relative_gap_limit(1e10); // Should stop at the first solution! int num_solutions = 0; diff --git a/ortools/sat/sat_runner.cc b/ortools/sat/sat_runner.cc index adaec46cef..5863dd06f0 100644 --- a/ortools/sat/sat_runner.cc +++ b/ortools/sat/sat_runner.cc @@ -11,8 +11,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include +#include #include +#include #include "absl/flags/flag.h" #include "absl/flags/parse.h" @@ -23,6 +27,7 @@ #include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "google/protobuf/arena.h" #include "google/protobuf/text_format.h" @@ -37,6 +42,7 @@ #include "ortools/sat/sat_cnf_reader.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/util/file_util.h" +#include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" ABSL_FLAG( @@ -67,11 +73,20 @@ ABSL_FLAG(bool, wcnf_use_strong_slack, true, "enforce the fact that when it is true, the clause must be false."); ABSL_FLAG(bool, fingerprint_intermediate_solutions, false, "Attach the fingerprint of intermediate solutions to the output."); +ABSL_FLAG(bool, competition_mode, false, + "If true, output the log in a competition format."); namespace operations_research { namespace sat { namespace { +struct CompetitionLog { + std::function log_callback = nullptr; + std::function response_callback = nullptr; + std::function final_response_callback = + nullptr; +}; + void TryToRemoveSuffix(absl::string_view suffix, std::string* str) { if (file::Extension(*str) == suffix) *str = file::Stem(*str); } @@ -91,13 +106,79 @@ std::string ExtractName(absl::string_view full_filename) { } bool LoadProblem(const std::string& filename, absl::string_view hint_file, - absl::string_view domain_file, CpModelProto* cp_model) { + absl::string_view domain_file, CpModelProto* cp_model, + CompetitionLog* competition_log) { if (absl::EndsWith(filename, ".opb") || absl::EndsWith(filename, ".opb.bz2") || absl::EndsWith(filename, ".opb.gz")) { OpbReader reader; if (!reader.LoadAndValidate(filename, cp_model)) { - LOG(FATAL) << "Cannot load file '" << filename << "'."; + if (absl::GetFlag(FLAGS_competition_mode)) { + std::cout << "s UNSUPPORTED" << std::endl; + } else { + LOG(FATAL) << "Cannot load file '" << filename << "'."; + } + } + + if (absl::GetFlag(FLAGS_competition_mode)) { + competition_log->log_callback = [](const std::string& multi_line_input) { + if (multi_line_input.empty()) { + std::cout << "c" << std::endl; + return; + } + const std::vector lines = + absl::StrSplit(multi_line_input, '\n'); + for (const absl::string_view& line : lines) { + std::cout << "c " << line << std::endl; + } + }; + competition_log->response_callback = [](const CpSolverResponse& r) { + std::cout << "o " << static_cast(r.objective_value()) + << std::endl; + }; + const int num_variables = reader.num_variables(); + const bool has_objective = cp_model->has_objective(); + competition_log->final_response_callback = + [num_variables, has_objective](const CpSolverResponse& r) { + switch (r.status()) { + case CpSolverStatus::OPTIMAL: + if (has_objective) { + std::cout << "s OPTIMUM FOUND " << std::endl; + } else { + std::cout << "s SATISFIABLE" << std::endl; + } + break; + case CpSolverStatus::FEASIBLE: + std::cout << "s SATISFIABLE" << std::endl; + break; + case CpSolverStatus::INFEASIBLE: + std::cout << "s UNSATISFIABLE" << std::endl; + break; + case CpSolverStatus::MODEL_INVALID: + std::cout << "s UNSUPPORTED" << std::endl; + break; + case CpSolverStatus::UNKNOWN: + std::cout << "s UNKNOWN" << std::endl; + break; + default: + break; + } + std::string line; + for (int i = 0; i < num_variables; ++i) { + if (r.solution(i)) { + absl::StrAppend(&line, "x", i + 1, " "); + } else { + absl::StrAppend(&line, "-x", i + 1, " "); + } + if (line.size() >= 75) { + std::cout << "v " << line << std::endl; + line.clear(); + } + } + if (!line.empty()) { + std::cout << "v " << line << std::endl; + } + }; } } else if (absl::EndsWith(filename, ".cnf") || absl::EndsWith(filename, ".cnf.xz") || @@ -181,15 +262,28 @@ int Run() { google::protobuf::Arena arena; CpModelProto* cp_model = google::protobuf::Arena::Create(&arena); + CompetitionLog competition_log; if (!LoadProblem(absl::GetFlag(FLAGS_input), absl::GetFlag(FLAGS_hint_file), - absl::GetFlag(FLAGS_domain_file), cp_model)) { + absl::GetFlag(FLAGS_domain_file), cp_model, + &competition_log)) { CpSolverResponse response; response.set_status(CpSolverStatus::MODEL_INVALID); + if (competition_log.final_response_callback != nullptr) { + competition_log.final_response_callback(response); + } return EXIT_SUCCESS; } Model model; model.Add(NewSatParameters(parameters)); + if (competition_log.log_callback != nullptr) { + model.GetOrCreate()->AddInfoLoggingCallback( + competition_log.log_callback); + model.GetOrCreate()->set_log_to_stdout(false); + } + if (competition_log.response_callback != nullptr) { + model.Add(NewFeasibleSolutionObserver(competition_log.response_callback)); + } if (absl::GetFlag(FLAGS_fingerprint_intermediate_solutions)) { // Let's add a solution callback that will display the fingerprint of all // solutions. @@ -200,6 +294,9 @@ int Run() { })); } const CpSolverResponse response = SolveCpModel(*cp_model, &model); + if (competition_log.final_response_callback != nullptr) { + competition_log.final_response_callback(response); + } if (!absl::GetFlag(FLAGS_output).empty()) { if (absl::EndsWith(absl::GetFlag(FLAGS_output), "txt")) {