diff --git a/Makefile.cpp.mk b/Makefile.cpp.mk index 549f04fdd8..3a88297c5d 100644 --- a/Makefile.cpp.mk +++ b/Makefile.cpp.mk @@ -27,6 +27,7 @@ CPBINARIES = \ golomb$E \ linear_assignment_example$E \ magic_square$E \ + model_util \ network_routing$E \ nqueens$E \ solve_dimacs_assignment$E \ @@ -93,7 +94,9 @@ CONSTRAINT_SOLVER_LIB_OS = \ objs/expressions.$O\ objs/hybrid.$O\ objs/interval.$O\ + objs/io.$O\ objs/local_search.$O\ + objs/model.pb.$O\ objs/nogoods.$O\ objs/pack.$O\ objs/range_cst.$O\ @@ -116,11 +119,11 @@ objs/assignment.pb.$O:gen/constraint_solver/assignment.pb.cc $(CCC) $(CFLAGS) -c gen/constraint_solver/assignment.pb.cc $(OBJOUT)objs/assignment.pb.$O gen/constraint_solver/assignment.pb.cc:constraint_solver/assignment.proto - $(PROTOBUF_DIR)/bin/protoc --proto_path=constraint_solver --cpp_out=gen/constraint_solver constraint_solver/assignment.proto + $(PROTOBUF_DIR)/bin/protoc --proto_path=. --cpp_out=gen constraint_solver/assignment.proto gen/constraint_solver/assignment.pb.h:gen/constraint_solver/assignment.pb.cc -objs/constraint_solver.$O:constraint_solver/constraint_solver.cc +objs/constraint_solver.$O:constraint_solver/constraint_solver.cc gen/constraint_solver/model.pb.h $(CCC) $(CFLAGS) -c constraint_solver/constraint_solver.cc $(OBJOUT)objs/constraint_solver.$O objs/constraints.$O:constraint_solver/constraints.cc @@ -139,7 +142,7 @@ objs/demon_profiler.pb.$O:gen/constraint_solver/demon_profiler.pb.cc $(CCC) $(CFLAGS) -c gen/constraint_solver/demon_profiler.pb.cc $(OBJOUT)objs/demon_profiler.pb.$O gen/constraint_solver/demon_profiler.pb.cc:constraint_solver/demon_profiler.proto - $(PROTOBUF_DIR)/bin/protoc --proto_path=constraint_solver --cpp_out=gen/constraint_solver constraint_solver/demon_profiler.proto + $(PROTOBUF_DIR)/bin/protoc --proto_path=. --cpp_out=gen constraint_solver/demon_profiler.proto gen/constraint_solver/demon_profiler.pb.h:gen/constraint_solver/demon_profiler.pb.cc @@ -164,9 +167,20 @@ objs/hybrid.$O:constraint_solver/hybrid.cc objs/interval.$O:constraint_solver/interval.cc $(CCC) $(CFLAGS) -c constraint_solver/interval.cc $(OBJOUT)objs/interval.$O +objs/io.$O:constraint_solver/io.cc gen/constraint_solver/model.pb.h + $(CCC) $(CFLAGS) -c constraint_solver/io.cc $(OBJOUT)objs/io.$O + objs/local_search.$O:constraint_solver/local_search.cc $(CCC) $(CFLAGS) -c constraint_solver/local_search.cc $(OBJOUT)objs/local_search.$O +objs/model.pb.$O:gen/constraint_solver/model.pb.cc + $(CCC) $(CFLAGS) -c gen/constraint_solver/model.pb.cc $(OBJOUT)objs/model.pb.$O + +gen/constraint_solver/model.pb.cc:constraint_solver/model.proto + $(PROTOBUF_DIR)/bin/protoc --proto_path=. --cpp_out=gen constraint_solver/model.proto + +gen/constraint_solver/model.pb.h:gen/constraint_solver/model.pb.cc gen/constraint_solver/search_limit.pb.h + objs/nogoods.$O:constraint_solver/nogoods.cc $(CCC) $(CFLAGS) -c constraint_solver/nogoods.cc $(OBJOUT)objs/nogoods.$O @@ -189,7 +203,7 @@ objs/search_limit.pb.$O:gen/constraint_solver/search_limit.pb.cc $(CCC) $(CFLAGS) -c gen/constraint_solver/search_limit.pb.cc $(OBJOUT)objs/search_limit.pb.$O gen/constraint_solver/search_limit.pb.cc:constraint_solver/search_limit.proto - $(PROTOBUF_DIR)/bin/protoc --proto_path=constraint_solver --cpp_out=gen/constraint_solver constraint_solver/search_limit.proto + $(PROTOBUF_DIR)/bin/protoc --proto_path=. --cpp_out=gen constraint_solver/search_limit.proto gen/constraint_solver/search_limit.pb.h:gen/constraint_solver/search_limit.pb.cc @@ -234,7 +248,7 @@ objs/linear_solver.pb.$O:gen/linear_solver/linear_solver.pb.cc $(CCC) $(CFLAGS) -c gen/linear_solver/linear_solver.pb.cc $(OBJOUT)objs/linear_solver.pb.$O gen/linear_solver/linear_solver.pb.cc:linear_solver/linear_solver.proto - $(PROTOBUF_DIR)/bin/protoc --proto_path=linear_solver --cpp_out=gen/linear_solver linear_solver/linear_solver.proto + $(PROTOBUF_DIR)/bin/protoc --proto_path=. --cpp_out=gen linear_solver/linear_solver.proto gen/linear_solver/linear_solver.pb.h:gen/linear_solver/linear_solver.pb.cc @@ -416,64 +430,70 @@ solve_dimacs_assignment$E: $(ALGORITHMS_LIBS) $(BASE_LIBS) $(DIMACS_LIBS) $(GRAP # Pure CP and Routing Examples -objs/costas_array.$O: examples/costas_array.cc +objs/costas_array.$O: examples/costas_array.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/costas_array.cc $(OBJOUT)objs/costas_array.$O costas_array$E: $(CP_LIBS) $(BASE_LIBS) objs/costas_array.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/costas_array.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)costas_array$E -objs/cryptarithm.$O:examples/cryptarithm.cc +objs/cryptarithm.$O:examples/cryptarithm.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/cryptarithm.cc $(OBJOUT)objs/cryptarithm.$O cryptarithm$E: $(CP_LIBS) $(BASE_LIBS) objs/cryptarithm.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/cryptarithm.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)cryptarithm$E -objs/cvrptw.$O: examples/cvrptw.cc +objs/cvrptw.$O: examples/cvrptw.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/cvrptw.cc $(OBJOUT)objs/cvrptw.$O cvrptw$E: $(CP_LIBS) $(BASE_LIBS) objs/cvrptw.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/cvrptw.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)cvrptw$E -objs/dobble_ls.$O:examples/dobble_ls.cc +objs/dobble_ls.$O:examples/dobble_ls.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/dobble_ls.cc $(OBJOUT)objs/dobble_ls.$O dobble_ls$E: $(CP_LIBS) $(BASE_LIBS) objs/dobble_ls.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/dobble_ls.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)dobble_ls$E -objs/golomb.$O:examples/golomb.cc +objs/golomb.$O:examples/golomb.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/golomb.cc $(OBJOUT)objs/golomb.$O golomb$E: $(CP_LIBS) $(BASE_LIBS) objs/golomb.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/golomb.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)golomb$E -objs/magic_square.$O:examples/magic_square.cc +objs/magic_square.$O:examples/magic_square.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/magic_square.cc $(OBJOUT)objs/magic_square.$O magic_square$E: $(CP_LIBS) $(BASE_LIBS) objs/magic_square.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/magic_square.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)magic_square$E -objs/network_routing.$O:examples/network_routing.cc +objs/model_util.$O:examples/model_util.cc gen/constraint_solver/model.pb.h constraint_solver/constraint_solver.h + $(CCC) $(CFLAGS) -c examples/model_util.cc $(OBJOUT)objs/model_util.$O + +model_util$E: $(CP_LIBS) $(BASE_LIBS) objs/model_util.$O + $(CCC) $(CFLAGS) $(LDFLAGS) objs/model_util.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)model_util$E + +objs/network_routing.$O:examples/network_routing.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/network_routing.cc $(OBJOUT)objs/network_routing.$O network_routing$E: $(CP_LIBS) $(BASE_LIBS) $(GRAPH_LIBS) objs/network_routing.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/network_routing.$O $(CP_LIBS) $(GRAPH_LIBS) $(BASE_LIBS) $(EXEOUT)network_routing$E -objs/nqueens.$O: examples/nqueens.cc +objs/nqueens.$O: examples/nqueens.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/nqueens.cc $(OBJOUT)objs/nqueens.$O nqueens$E: $(CP_LIBS) $(BASE_LIBS) objs/nqueens.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/nqueens.$O $(CP_LIBS) $(BASE_LIBS) $(EXEOUT)nqueens$E -objs/tricks.$O: examples/tricks.cc +objs/tricks.$O: examples/tricks.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/tricks.cc $(OBJOUT)objs/tricks.$O -objs/global_arith.$O: examples/global_arith.cc +objs/global_arith.$O: examples/global_arith.cc constraint_solver/constraint_solver.h $(CCC) $(CFLAGS) -c examples/global_arith.cc $(OBJOUT)objs/global_arith.$O tricks$E: $(CPLIBS) $(BASE_LIBS) objs/tricks.$O objs/global_arith.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/tricks.$O objs/global_arith.$O $(CPLIBS) $(BASE_LIBS) $(EXEOUT)tricks$E -objs/tsp.$O: examples/tsp.cc +objs/tsp.$O: examples/tsp.cc constraint_solver/routing.h $(CCC) $(CFLAGS) -c examples/tsp.cc $(OBJOUT)objs/tsp.$O tsp$E: $(CP_LIBS) $(BASE_LIBS) objs/tsp.$O @@ -481,19 +501,19 @@ tsp$E: $(CP_LIBS) $(BASE_LIBS) objs/tsp.$O # Linear Programming Examples -objs/linear_solver_example.$O: examples/linear_solver_example.cc +objs/linear_solver_example.$O: examples/linear_solver_example.cc linear_solver/linear_solver.h $(CCC) $(CFLAGS) -c examples/linear_solver_example.cc $(OBJOUT)objs/linear_solver_example.$O linear_solver_example$E: $(LP_LIBS) $(BASE_LIBS) objs/linear_solver_example.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/linear_solver_example.$O $(LP_LIBS) $(BASE_LIBS) $(LDLPDEPS) $(EXEOUT)linear_solver_example$E -objs/linear_solver_example_with_protocol_buffers.$O: examples/linear_solver_example_with_protocol_buffers.cc +objs/linear_solver_example_with_protocol_buffers.$O: examples/linear_solver_example_with_protocol_buffers.cc linear_solver/linear_solver.h $(CCC) $(CFLAGS) -c examples/linear_solver_example_with_protocol_buffers.cc $(OBJOUT)objs/linear_solver_example_with_protocol_buffers.$O linear_solver_example_with_protocol_buffers$E: $(LP_LIBS) $(BASE_LIBS) objs/linear_solver_example_with_protocol_buffers.$O $(CCC) $(CFLAGS) $(LDFLAGS) objs/linear_solver_example_with_protocol_buffers.$O $(LP_LIBS) $(BASE_LIBS) $(LDLPDEPS) $(EXEOUT)linear_solver_example_with_protocol_buffers$E -objs/integer_solver_example.$O: examples/integer_solver_example.cc +objs/integer_solver_example.$O: examples/integer_solver_example.cc linear_solver/linear_solver.h $(CCC) $(CFLAGS) -c examples/integer_solver_example.cc $(OBJOUT)objs/integer_solver_example.$O integer_solver_example$E: $(LP_LIBS) $(BASE_LIBS) objs/integer_solver_example.$O diff --git a/constraint_solver/alldiff_cst.cc b/constraint_solver/alldiff_cst.cc index 1fe0a53422..7f42b5adff 100644 --- a/constraint_solver/alldiff_cst.cc +++ b/constraint_solver/alldiff_cst.cc @@ -18,6 +18,7 @@ #include "base/logging.h" #include "base/scoped_ptr.h" #include "constraint_solver/constraint_solveri.h" +#include "util/string_array.h" namespace operations_research { namespace { diff --git a/constraint_solver/assignment.cc b/constraint_solver/assignment.cc index b237587078..74a8fc51ee 100644 --- a/constraint_solver/assignment.cc +++ b/constraint_solver/assignment.cc @@ -295,8 +295,8 @@ void LoadElement(const hash_map& id_to_element_map, bool Assignment::Load(const string& filename) { File::Init(); - File* file = File::Create(filename, "r"); - if (file == NULL || !file->Open()) { + File* file = File::Open(filename, "r"); + if (file == NULL) { LOG(INFO) << "Cannot open " << filename; return false; } @@ -382,8 +382,8 @@ void Assignment::Load(const AssignmentProto& assignment_proto) { bool Assignment::Save(const string& filename) { File::Init(); - File* file = File::Create(filename, "w"); - if (file == NULL || !file->Open()) { + File* file = File::Open(filename, "w"); + if (file == NULL) { LOG(INFO) << "Cannot open " << filename; return false; } diff --git a/constraint_solver/constraint_solver.cc b/constraint_solver/constraint_solver.cc index 15ad102381..504531c93f 100644 --- a/constraint_solver/constraint_solver.cc +++ b/constraint_solver/constraint_solver.cc @@ -28,11 +28,14 @@ #include "base/macros.h" #include "base/scoped_ptr.h" #include "base/stringprintf.h" +#include "base/file.h" +#include "base/recordio.h" #include "zlib.h" #include "base/stringpiece.h" #include "base/concise_iterator.h" #include "base/map-util.h" #include "constraint_solver/constraint_solveri.h" +#include "constraint_solver/model.pb.h" #include "util/const_int_array.h" DEFINE_bool(cp_trace_demons, false, "trace all demon executions."); @@ -42,6 +45,8 @@ DEFINE_bool(cp_print_model, false, "use PrintModelVisitor on model before solving."); DEFINE_bool(cp_model_stats, false, "use StatisticsModelVisitor on model before solving."); +DEFINE_string(cp_export_file, "", "Export model to file using CPModelProto."); +DEFINE_bool(cp_no_solve, false, "Force failure at the beginning of a search"); void ConstraintSolverFailHere() { VLOG(3) << "Fail"; @@ -1352,6 +1357,7 @@ void Solver::Init() { InitCachedIntConstants(); // to be called after the SENTINEL is set. InitCachedConstraint(); // Cache the true constraint. InitBoolVarCaches(); + InitBuilders(); timer_->Restart(); } @@ -1373,6 +1379,7 @@ Solver::~Solver() { << "non empty list of searches when ending the solver"; delete search; DeleteDemonMonitor(demon_monitor_); + DeleteBuilders(); } const SolverParameters::TrailCompression @@ -1626,6 +1633,25 @@ void Solver::ProcessConstraints() { ModelVisitor* const visitor = MakeStatisticsModelVisitor(); Accept(visitor); } + if (!FLAGS_cp_export_file.empty()) { + File::Init(); + File* file = File::Open(FLAGS_cp_export_file, "w"); + if (file == NULL) { + LOG(WARNING) << "Cannot open " << FLAGS_cp_export_file; + } else { + CPModelProto export_proto; + ExportModel(&export_proto); + VLOG(1) << export_proto.DebugString(); + RecordWriter writer(file); + writer.WriteProtocolMessage(export_proto); + writer.Close(); + } + } + + if (FLAGS_cp_no_solve) { + LOG(INFO) << "Forcing early failure"; + Fail(); + } // Clear state before processing constraints. const int constraints_size = constraints_list_.size(); @@ -2381,7 +2407,7 @@ void DecisionVisitor::VisitTryRankFirst(Sequence* const sequence, int index) {} // ---------- ModelVisitor ---------- -// Enums. +// Tags for constraints, arguments, extensions. const char ModelVisitor::kAbs[] = "Abs"; const char ModelVisitor::kAllDifferent[] = "AllDifferent"; diff --git a/constraint_solver/constraint_solver.h b/constraint_solver/constraint_solver.h index 364350f6a1..0b7af40f5c 100644 --- a/constraint_solver/constraint_solver.h +++ b/constraint_solver/constraint_solver.h @@ -101,6 +101,12 @@ class BaseObject; class ClockTimer; class ConstIntArray; class Constraint; +class CPArgumentProto; +class CPConstraintProto; +class CPIntegerExpressionProto; +class CPIntervalVariableProto; +class CPModelBuilder; +class CPModelProto; class Decision; class DecisionBuilder; class DecisionVisitor; @@ -142,6 +148,7 @@ class SymmetryBreaker; class UnequalityVarCstCache; struct StateInfo; struct Trail; +template class ConstPtrArray; template class SimpleRevFIFO; // This enum is used internally to tag states in the search tree. @@ -297,6 +304,17 @@ class Solver { typedef ResultCallback1 IndexEvaluator1; typedef ResultCallback2 IndexEvaluator2; typedef ResultCallback3 IndexEvaluator3; + typedef ResultCallback2 + IntegerExpressionBuilder; + typedef ResultCallback2 ConstraintBuilder; + typedef ResultCallback2 + IntervalVariableBuilder; // Number of priorities for demons. static const int kNumPriorities = 3; @@ -556,6 +574,37 @@ class Solver { // Abandon the current branch in the search tree. A backtrack will follow. void Fail(); + // Exports the model to protobuf. This code will be called + // from inside the solver during the start of the search. + void ExportModel(CPModelProto* const proto) const; + // Exports the model to protobuf. Search monitors are useful to pass + // the objective and limits to the protobuf. + void ExportModel(const std::vector& monitors, + CPModelProto* const proto) const; + // Loads the model into the solver, and returns true upon success. + bool LoadModel(const CPModelProto& proto); + // Loads the model into the solver, appends search monitors to monitors, + // and returns true upon success. + bool LoadModel(const CPModelProto& proto, std::vector* monitors); + // Upgrades the model to the latest version. + static bool UpgradeModel(CPModelProto* const proto); + + // Registers a constraint builder. Ownership is passed to the solver. + void RegisterBuilder(const string& tag, + ConstraintBuilder* const builder); + // Registers a integer expression builder. Ownership is passed to the solver. + void RegisterBuilder(const string& tag, + IntegerExpressionBuilder* const builder); + // Registers a interval variable builder. Ownership is passed to the solver. + void RegisterBuilder(const string& tag, + IntervalVariableBuilder* const builder); + + ConstraintBuilder* GetConstraintBuilder(const string& tag) const; + IntegerExpressionBuilder* + GetIntegerExpressionBuilder(const string& tag) const; + IntervalVariableBuilder* GetIntervalVariableBuilder(const string& tag) const; + + // When SaveValue() is not the best way to go, one can create a reversible // action that will be called upon backtrack. The "fast" parameter // indicates whether we need restore all values saved through SaveValue() @@ -2297,6 +2346,8 @@ class Solver { void InitCachedIntConstants(); void InitCachedConstraint(); void InitBoolVarCaches(); + void InitBuilders(); + void DeleteBuilders(); // Naming string GetName(const PropagationBaseObject* object) const; @@ -2351,6 +2402,11 @@ class Solver { int constraint_index_; int additional_constraint_index_; + // Support for model loading. + hash_map expression_builders_; + hash_map constraint_builders_; + hash_map interval_builders_; + DISALLOW_COPY_AND_ASSIGN(Solver); }; diff --git a/constraint_solver/constraint_solver.swig b/constraint_solver/constraint_solver.swig index 6bb2f0950f..cd596c8435 100644 --- a/constraint_solver/constraint_solver.swig +++ b/constraint_solver/constraint_solver.swig @@ -25,6 +25,8 @@ DECLARE_bool(cp_trace_demons); DECLARE_bool(cp_print_model); DECLARE_bool(cp_model_stats); +DECLARE_string(cp_export_file); +DECLARE_bool(cp_no_solve); struct FailureProtect { jmp_buf exception_buffer; @@ -228,6 +230,10 @@ gflags.DEFINE_boolean('cp_print_model', False, 'prints the model before solving it.') gflags.DEFINE_boolean('cp_model_stats', False, 'displays model statistics before solving it.') +gflags.DEFINE_string('cp_export_file', '', + 'exports model to file using CPModelProto.') +gflags.DEFINE_boolean('cp_no_solve', False, + 'force failures at the beginning of a search.') } %pythoncode { @@ -861,20 +867,28 @@ namespace operations_research { // Add display methods on Solver and remove DebugString method. %ignore Solver::DebugString; +// Indentation is critical here as the code is copied verbatim in the +// python code. %feature("pythonappend") Solver::Solver %{ Solver.SetPythonFlags(FLAGS.cp_trace_demons, FLAGS.cp_print_model, - FLAGS.cp_model_stats) + FLAGS.cp_model_stats, + FLAGS.cp_export_file, + FLAGS.cp_no_solve) %} %extend Solver { static void SetPythonFlags(bool trace_demon, bool print_model, - bool model_stats) { + bool model_stats, + string export_file, + bool no_solve) { FLAGS_cp_trace_demons = trace_demon; FLAGS_cp_print_model = print_model; FLAGS_cp_model_stats = model_stats; + FLAGS_cp_export_file = export_file; + FLAGS_cp_no_solve = no_solve; } Constraint* TreeNoCycle(const std::vector& nexts, diff --git a/constraint_solver/constraint_solveri.h b/constraint_solver/constraint_solveri.h index a1ce8efdc5..f0d9b1e125 100644 --- a/constraint_solver/constraint_solveri.h +++ b/constraint_solver/constraint_solveri.h @@ -31,6 +31,7 @@ #include "base/bitmap.h" #include "base/map-util.h" #include "constraint_solver/constraint_solver.h" +#include "util/vector_map.h" class WallTimer; @@ -125,47 +126,6 @@ template class SimpleRevFIFO { int pos_; }; -// ---------- Pretty Print Helpers ---------- - -template string DebugStringArray(T* const* array, - int size, - const string& separator) { - string out; - for (int i = 0; i < size; ++i) { - if (i > 0) { - out.append(separator); - } - out.append(array[i]->DebugString()); - } - return out; -} - -template string NameArray(T* const* array, - int size, - const string& separator) { - string out; - for (int i = 0; i < size; ++i) { - if (i > 0) { - out.append(separator); - } - out.append(array[i]->name()); - } - return out; -} - -inline string Int64ArrayToString(const int64* const array, - int size, - const string& separator) { - string out; - for (int i = 0; i < size; ++i) { - if (i > 0) { - out.append(separator); - } - StringAppendF(&out, "%" GG_LL_FORMAT "d", array[i]); - } - return out; -} - // These methods represents generic demons that will call back a // method on the constraint during their Run method. @@ -870,6 +830,286 @@ class SearchLog : public SearchMonitor { int sliding_min_depth_; int sliding_max_depth_; }; + +// ---------- CPModelBuilder ----------- + +class CPModelBuilder { + public: + explicit CPModelBuilder(Solver* const solver) : solver_(solver) {} + ~CPModelBuilder() {} + + Solver* solver() const { return solver_; } + + // Builds integer expression from proto and stores it. It returns + // true upon success. + bool BuildFromProto(const CPIntegerExpressionProto& proto); + // Builds constraint from proto and returns it. + Constraint* BuildFromProto(const CPConstraintProto& proto); + // Builds interval variable from proto and stores it. It returns + // true upon success. + bool BuildFromProto(const CPIntervalVariableProto& proto); + // Returns stored integer expression. + IntExpr* IntegerExpression(int index) const; + // Returns stored interval variable. + IntervalVar* IntervalVariable(int index) const; + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + int64* to_fill); + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto , + IntExpr** to_fill); + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector* to_fill); + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector >* to_fill); + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector* to_fill); + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + IntervalVar** to_fill); + + bool ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector* to_fill); + + template bool ScanArguments(const string& type, + const P& proto, + A* to_fill) { + const int index = tags_.Index(type); + for (int i = 0; i < proto.arguments_size(); ++i) { + if (ScanOneArgument(index, proto.arguments(i), to_fill)) { + return true; + } + } + return false; + } + + int TagIndex(const string& tag) const { return tags_.Index(tag); } + + void AddTag(const string& tag) { tags_.Add(tag); } + + private: + Solver* const solver_; + std::vector expressions_; + std::vector intervals_; + VectorMap tags_; +}; + +// Implements a complete cache for model elements: expressions and constraints. +// Caching is based on the signature of the elements, as well as their type. +class ModelCache { + public: + enum VoidConstraintType { + VOID_FALSE_CONSTRAINT = 0, + VOID_TRUE_CONSTRAINT, + VOID_CONSTRAINT_MAX, + }; + + enum VarConstantConstraintType { + VAR_CONSTANT_EQUALITY = 0, + VAR_CONSTANT_GREATER_OR_EQUAL, + VAR_CONSTANT_LESS_OR_EQUAL, + VAR_CONSTANT_NON_EQUALITY, + VAR_CONSTANT_CONSTRAINT_MAX, + }; + + enum VarConstantConstantConstraintType { + VAR_CONSTANT_CONSTANT_BETWEEN = 0, + VAR_CONSTANT_CONSTANT_CONSTRAINT_MAX, + }; + + enum VarVarConstraintType { + VAR_VAR_EQUALITY = 0, + VAR_VAR_GREATER, + VAR_VAR_GREATER_OR_EQUAL, + VAR_VAR_LESS, + VAR_VAR_LESS_OR_EQUAL, + VAR_VAR_NON_EQUALITY, + VAR_VAR_CONSTRAINT_MAX, + }; + + enum VarExpressionType { + VAR_OPPOSITE = 0, + VAR_ABS, + VAR_SQUARE, + VAR_EXPRESSION_MAX, + }; + + enum VarConstantExpressionType { + VAR_CONSTANT_DIFFERENCE = 0, + VAR_CONSTANT_DIVIDE, + VAR_CONSTANT_PROD, + VAR_CONSTANT_MAX, + VAR_CONSTANT_MIN, + VAR_CONSTANT_SUM, + VAR_CONSTANT_EXPRESSION_MAX, + }; + + enum VarVarExpressionType { + VAR_VAR_DIFFERENCE = 0, + VAR_VAR_PROD, + VAR_VAR_MAX, + VAR_VAR_MIN, + VAR_VAR_SUM, + VAR_VAR_EXPRESSION_MAX, + }; + + enum VarConstantConstantExpressionType { + VAR_CONSTANT_CONSTANT_SEMI_CONTINUOUS = 0, + VAR_CONSTANT_CONSTANT_EXPRESSION_MAX, + }; + + enum VarConstantArrayExpressionType { + VAR_CONSTANT_ARRAY_ELEMENT = 0, + VAR_CONSTANT_ARRAY_EXPRESSION_MAX, + }; + + enum VarArrayExpressionType { + VAR_ARRAY_MAX = 0, + VAR_ARRAY_MIN, + VAR_ARRAY_SUM, + VAR_ARRAY_EXPRESSION_MAX, + }; + + explicit ModelCache(Solver* const solver); + virtual ~ModelCache(); + + // Void constraints. + + virtual Constraint* FindVoidConstraint(VoidConstraintType type) const = 0; + + virtual void InsertVoidConstraint(Constraint* const ct, + VoidConstraintType type) = 0; + + // Var Constant Constraints. + + virtual Constraint* FindVarConstantConstraint( + IntVar* const var, + int64 value, + VarConstantConstraintType type) const = 0; + + virtual void InsertVarConstantConstraint( + Constraint* const ct, + IntVar* const var, + int64 value, + VarConstantConstraintType type) = 0; + + // Var Constant Constant Constraints. + + virtual Constraint* FindVarConstantConstantConstraint( + IntVar* const var, + int64 value1, + int64 value2, + VarConstantConstantConstraintType type) const = 0; + + virtual void InsertVarConstantConstantConstraint( + Constraint* const ct, + IntVar* const var, + int64 value1, + int64 value2, + VarConstantConstantConstraintType type) = 0; + + // Var Var Constraints. + + virtual Constraint* FindVarVarConstraint( + IntVar* const var1, + IntVar* const var2, + VarVarConstraintType type) const = 0; + + virtual void InsertVarVarConstraint(Constraint* const ct, + IntVar* const var1, + IntVar* const var2, + VarVarConstraintType type) = 0; + + // Var Expressions. + + virtual IntExpr* FindVarExpression( + IntVar* const var, + VarExpressionType type) const = 0; + + virtual void InsertVarExpression(IntExpr* const expression, + IntVar* const var, + VarExpressionType type) = 0; + + // Var Constant Expressions . + + virtual IntExpr* FindVarConstantExpression( + IntVar* const var, + int64 value, + VarConstantExpressionType type) const = 0; + + virtual void InsertVarConstantExpression( + IntExpr* const expression, + IntVar* const var, + int64 value, + VarConstantExpressionType type) = 0; + + // Var Var Expressions. + + virtual IntExpr* FindVarVarExpression( + IntVar* const var1, + IntVar* const var2, + VarVarExpressionType type) const = 0; + + virtual void InsertVarVarExpression( + IntExpr* const expression, + IntVar* const var1, + IntVar* const var2, + VarVarExpressionType type) = 0; + + // Var Constant Constant Expressions. + + virtual IntExpr* FindVarConstantConstantExpression( + IntVar* const var, + int64 value1, + int64 value2, + VarConstantConstantExpressionType type) const = 0; + + virtual void InsertVarConstantConstantExpression( + IntExpr* const expression, + IntVar* const var, + int64 value1, + int64 value2, + VarConstantConstantExpressionType type) = 0; + + // Var Constant Array Expressions. + + virtual IntExpr* FindVarConstantArrayExpression( + IntVar* const var, + ConstIntArray* const values, + VarConstantArrayExpressionType type) const = 0; + + virtual void InsertVarConstantArrayExpression( + IntExpr* const expression, + IntVar* const var, + ConstIntArray* const values, + VarConstantArrayExpressionType type) = 0; + + // Var Array Expressions. + + virtual IntExpr* FindVarArrayExpression( + ConstPtrArray* const vars, + VarArrayExpressionType type) const = 0; + + virtual void InsertVarArrayExpression( + IntExpr* const expression, + ConstPtrArray* const vars, + VarArrayExpressionType type) = 0; + + Solver* solver() const; + + private: + Solver* const solver_; +}; } // namespace operations_research #endif // OR_TOOLS_CONSTRAINT_SOLVER_CONSTRAINT_SOLVERI_H_ diff --git a/constraint_solver/count_cst.cc b/constraint_solver/count_cst.cc index 813f2efc77..b151267aca 100644 --- a/constraint_solver/count_cst.cc +++ b/constraint_solver/count_cst.cc @@ -19,6 +19,7 @@ #include "base/stringprintf.h" #include "base/concise_iterator.h" #include "constraint_solver/constraint_solveri.h" +#include "util/string_array.h" namespace operations_research { diff --git a/constraint_solver/deviation.cc b/constraint_solver/deviation.cc index e789f639c5..44eba99c2c 100644 --- a/constraint_solver/deviation.cc +++ b/constraint_solver/deviation.cc @@ -21,6 +21,7 @@ #include "base/mathutil.h" #include "constraint_solver/constraint_solver.h" #include "constraint_solver/constraint_solveri.h" +#include "util/string_array.h" namespace operations_research { // Deviation Constraint, a constraint for the average absolute diff --git a/constraint_solver/element.cc b/constraint_solver/element.cc index 1d7bc580c3..6f74b3b76f 100644 --- a/constraint_solver/element.cc +++ b/constraint_solver/element.cc @@ -21,6 +21,7 @@ #include "base/stringprintf.h" #include "constraint_solver/constraint_solveri.h" #include "util/const_int_array.h" +#include "util/string_array.h" namespace operations_research { diff --git a/constraint_solver/expr_array.cc b/constraint_solver/expr_array.cc index 053f192c48..a953b46db6 100644 --- a/constraint_solver/expr_array.cc +++ b/constraint_solver/expr_array.cc @@ -19,6 +19,7 @@ #include "base/scoped_ptr.h" #include "base/stringprintf.h" #include "constraint_solver/constraint_solveri.h" +#include "util/string_array.h" namespace operations_research { namespace { diff --git a/constraint_solver/io.cc b/constraint_solver/io.cc new file mode 100644 index 0000000000..506b017370 --- /dev/null +++ b/constraint_solver/io.cc @@ -0,0 +1,2261 @@ +// Copyright 2010-2011 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/integral_types.h" +#include "base/logging.h" +#include "base/concise_iterator.h" +#include "base/map-util.h" +#include "base/stl_util.h" +#include "constraint_solver/constraint_solveri.h" +#include "constraint_solver/model.pb.h" +#include "util/bitset.h" +#include "util/vector_map.h" + +namespace operations_research { +namespace { +// ---------- Model Protobuf Writers ----------- + +// ----- First Pass visitor ----- + +// This visitor collects all constraints and expressions. It sorts the +// expressions, such that we can build them in sequence using +// previously created expressions. +class FirstPassVisitor : public ModelVisitor { + public: + virtual ~FirstPassVisitor() {} + + // Begin/End visit element. + virtual void BeginVisitModel(const string& solver_name) { + // Reset statistics. + expression_map_.clear(); + delegate_map_.clear(); + expression_list_.clear(); + constraint_list_.clear(); + interval_list_.clear(); + } + + virtual void EndVisitConstraint(const string& type_name, + const Constraint* const constraint) { + Register(constraint); + } + + virtual void EndVisitIntegerExpression(const string& type_name, + const IntExpr* const expression) { + Register(expression); + } + + virtual void VisitIntegerVariable(const IntVar* const variable, + const IntExpr* const delegate) { + if (delegate != NULL) { + delegate->Accept(this); + delegate_map_[variable] = delegate; + } + Register(variable); + } + + virtual void VisitIntervalVariable(const IntervalVar* const variable, + const string operation, + const IntervalVar* const delegate) { + if (delegate != NULL) { + delegate->Accept(this); + } + Register(variable); + } + + virtual void VisitIntervalVariable(const IntervalVar* const variable, + const string operation, + const IntervalVar* const * delegates, + int size) { + for (int i = 0; i < size; ++i) { + delegates[i]->Accept(this); + } + Register(variable); + } + + // Visit integer expression argument. + virtual void VisitIntegerExpressionArgument( + const string& arg_name, + const IntExpr* const argument) { + VisitSubArgument(argument); + } + + virtual void VisitIntegerVariableArrayArgument( + const string& arg_name, + const IntVar* const * arguments, + int size) { + for (int i = 0; i < size; ++i) { + VisitSubArgument(arguments[i]); + } + } + + // Visit interval argument. + virtual void VisitIntervalArgument(const string& arg_name, + const IntervalVar* const argument) { + VisitSubArgument(argument); + } + + virtual void VisitIntervalArrayArgument(const string& arg_name, + const IntervalVar* const * arguments, + int size) { + for (int i = 0; i < size; ++i) { + VisitSubArgument(arguments[i]); + } + } + + // Export + const hash_map& expression_map() const { + return expression_map_; + } + const hash_map& interval_map() const { + return interval_map_; + } + const hash_map& delegate_map() const { + return delegate_map_; + } + const std::vector& expression_list() const { + return expression_list_; + } + const std::vector& constraint_list() const { + return constraint_list_; + } + const std::vector& interval_list() const { + return interval_list_; + } + + private: + void Register(const IntExpr* const expression) { + if (!ContainsKey(expression_map_, expression)) { + const int index = expression_map_.size(); + CHECK_EQ(index, expression_list_.size()); + expression_map_[expression] = index; + expression_list_.push_back(expression); + } + } + + void Register(const Constraint* const constraint) { + constraint_list_.push_back(constraint); + } + + void Register(const IntervalVar* const interval) { + if (!ContainsKey(interval_map_, interval)) { + const int index = interval_map_.size(); + CHECK_EQ(index, interval_list_.size()); + interval_map_[interval] = index; + interval_list_.push_back(interval); + } + } + + void VisitSubArgument(const IntExpr* const expression) { + if (!ContainsKey(expression_map_, expression)) { + expression->Accept(this); + } + } + + void VisitSubArgument(const IntervalVar* const interval) { + if (!ContainsKey(interval_map_, interval)) { + interval->Accept(this); + } + } + + const string filename_; + hash_map expression_map_; + hash_map interval_map_; + hash_map delegate_map_; + std::vector expression_list_; + std::vector constraint_list_; + std::vector interval_list_; +}; + +// ----- Argument Holder ----- + +class ArgumentHolder { + public: + template void ExportToProto(VectorMap* const tags, + P* const proto) const { + for (ConstIter > it(integer_argument_); + !it.at_end(); + ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + arg_proto->set_integer_value(it->second); + } + + for (ConstIter > > it( + integer_array_argument_); !it.at_end(); ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + for (int i = 0; i < it->second.size(); ++i) { + arg_proto->add_integer_array(it->second[i]); + } + } + + for (ConstIter > > > it( + integer_matrix_argument_); !it.at_end(); ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + CPIntegerMatrixProto* const matrix_proto = + arg_proto->mutable_integer_matrix(); + const int columns = it->second.first; + CHECK_GT(columns, 0); + const int rows = it->second.second.size() / columns; + matrix_proto->set_rows(rows); + matrix_proto->set_columns(columns); + for (int i = 0; i < it->second.second.size(); ++i) { + matrix_proto->add_values(it->second.second[i]); + } + } + + for (ConstIter > it( + integer_expression_argument_); !it.at_end(); ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + arg_proto->set_integer_expression_index(it->second); + } + + for (ConstIter > > it( + integer_variable_array_argument_); !it.at_end(); ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + for (int i = 0; i < it->second.size(); ++i) { + arg_proto->add_integer_expression_array(it->second[i]); + } + } + + for (ConstIter > it(interval_argument_); + !it.at_end(); + ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + arg_proto->set_interval_index(it->second); + } + + for (ConstIter > > it( + interval_array_argument_); !it.at_end(); ++it) { + CPArgumentProto* const arg_proto = proto->add_arguments(); + arg_proto->set_argument_index(tags->Add(it->first)); + for (int i = 0; i < it->second.size(); ++i) { + arg_proto->add_interval_array(it->second[i]); + } + } + } + + const string& type_name() const { + return type_name_; + } + + void set_type_name(const string& type_name) { + type_name_ = type_name; + } + + void set_integer_argument(const string& arg_name, int64 value) { + integer_argument_[arg_name] = value; + } + + void set_integer_array_argument(const string& arg_name, + const int64* const values, + int size) { + for (int i = 0; i < size; ++i) { + integer_array_argument_[arg_name].push_back(values[i]); + } + } + + void set_integer_matrix_argument(const string& arg_name, + const int64* const * const values, + int rows, + int columns) { + pair > matrix = make_pair(columns, std::vector()); + integer_matrix_argument_[arg_name] = matrix; + std::vector* const vals = &integer_matrix_argument_[arg_name].second; + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < columns; ++j) { + vals->push_back(values[i][j]); + } + } + } + + void set_integer_expression_argument(const string& arg_name, int index) { + integer_expression_argument_[arg_name] = index; + } + + void set_integer_variable_array_argument(const string& arg_name, + const int* const indices, + int size) { + for (int i = 0; i < size; ++i) { + integer_variable_array_argument_[arg_name].push_back(indices[i]); + } + } + + void set_interval_argument(const string& arg_name, int index) { + interval_argument_[arg_name] = index; + } + + void set_interval_array_argument(const string& arg_name, + const int* const indices, + int size) { + for (int i = 0; i < size; ++i) { + interval_array_argument_[arg_name].push_back(indices[i]); + } + } + + int64 FindIntegerArgumentWithDefault(const string& arg_name, int64 def) { + return FindWithDefault(integer_argument_, arg_name, def); + } + + int64 FindIntegerArgumentOrDie(const string& arg_name) { + return FindOrDie(integer_argument_, arg_name); + } + + int64 FindIntegerExpressionArgumentOrDie(const string& arg_name) { + return FindOrDie(integer_expression_argument_, arg_name); + } + + private: + string type_name_; + hash_map integer_expression_argument_; + hash_map integer_argument_; + hash_map interval_argument_; + hash_map > integer_array_argument_; + hash_map > > integer_matrix_argument_; + hash_map > integer_variable_array_argument_; + hash_map > interval_array_argument_; +}; + +// ----- Second Pass Visitor ----- + +static const int kModelVersion = 1; + +// The second pass visitor will visited sorted expressions, interval +// vars and expressions and export them to a CPModelProto protocol +// buffer. +class SecondPassVisitor : public ModelVisitor { + public: + SecondPassVisitor(const FirstPassVisitor& first_pass, + CPModelProto* const model_proto) + : expression_map_(first_pass.expression_map()), + interval_map_(first_pass.interval_map()), + delegate_map_(first_pass.delegate_map()), + expression_list_(first_pass.expression_list()), + constraint_list_(first_pass.constraint_list()), + interval_list_(first_pass.interval_list()), + model_proto_(model_proto) {} + + virtual ~SecondPassVisitor() {} + + virtual void BeginVisitModel(const string& model_name) { + model_proto_->set_model(model_name); + model_proto_->set_version(kModelVersion); + PushArgumentHolder(); + for (ConstIter > it(expression_list_); + !it.at_end(); + ++it) { + (*it)->Accept(this); + } + + for (ConstIter > it(interval_list_); + !it.at_end(); + ++it) { + (*it)->Accept(this); + } + } + + virtual void EndVisitModel(const string& model_name) { + for (ConstIter > it(extensions_); + !it.at_end(); + ++it) { + WriteModelExtension(*it); + } + PopArgumentHolder(); + // Write tags. + for (int i = 0; i < tags_.size(); ++i) { + model_proto_->add_tags(tags_.Element(i)); + } + } + + virtual void BeginVisitConstraint(const string& type_name, + const Constraint* const constraint) { + PushArgumentHolder(); + } + + virtual void EndVisitConstraint(const string& type_name, + const Constraint* const constraint) { + // We ignore delegate constraints, they will be regenerated automatically. + if (constraint->IsDelegate()) { + return; + } + + const int index = model_proto_->constraints_size(); + CPConstraintProto* const constraint_proto = model_proto_->add_constraints(); + ExportToProto(constraint, constraint_proto, type_name, index); + PopArgumentHolder(); + } + + virtual void BeginVisitIntegerExpression(const string& type_name, + const IntExpr* const expression) { + PushArgumentHolder(); + } + + virtual void EndVisitIntegerExpression(const string& type_name, + const IntExpr* const expression) { + const int index = model_proto_->expressions_size(); + CPIntegerExpressionProto* const expression_proto = + model_proto_->add_expressions(); + ExportToProto(expression, expression_proto, type_name, index); + PopArgumentHolder(); + } + + virtual void BeginVisitExtension(const string& type_name) { + PushExtension(type_name); + } + + virtual void EndVisitExtension(const string& type_name) { + PopAndSaveExtension(); + } + + virtual void VisitIntegerArgument(const string& arg_name, int64 value) { + top()->set_integer_argument(arg_name, value); + } + + virtual void VisitIntegerArrayArgument(const string& arg_name, + const int64* const values, + int size) { + top()->set_integer_array_argument(arg_name, values, size); + } + + virtual void VisitIntegerMatrixArgument(const string& arg_name, + const int64* const * const values, + int rows, + int columns) { + top()->set_integer_matrix_argument(arg_name, values, rows, columns); + } + + virtual void VisitIntegerExpressionArgument( + const string& arg_name, + const IntExpr* const argument) { + top()->set_integer_expression_argument(arg_name, + FindExpressionIndex(argument)); + } + + virtual void VisitIntegerVariableArrayArgument( + const string& arg_name, + const IntVar* const * arguments, + int size) { + std::vector indices; + for (int i = 0; i < size; ++i) { + indices.push_back(FindExpressionIndex(arguments[i])); + } + top()->set_integer_variable_array_argument(arg_name, + indices.data(), + indices.size()); + } + + virtual void VisitIntervalArgument( + const string& arg_name, + const IntervalVar* argument) { + top()->set_interval_argument(arg_name, FindIntervalIndex(argument)); + } + + virtual void VisitIntervalArrayArgument( + const string& arg_name, + const IntervalVar* const * arguments, + int size) { + std::vector indices; + for (int i = 0; i < size; ++i) { + indices.push_back(FindIntervalIndex(arguments[i])); + } + top()->set_interval_array_argument(arg_name, + indices.data(), + indices.size()); + } + + virtual void VisitIntegerVariable(const IntVar* const variable, + const IntExpr* const delegate) { + if (delegate != NULL) { + const int index = model_proto_->expressions_size(); + CPIntegerExpressionProto* const var_proto = + model_proto_->add_expressions(); + var_proto->set_index(index); + var_proto->set_type_index(TagIndex(ModelVisitor::kIntegerVariable)); + CPArgumentProto* const sub_proto = var_proto->add_arguments(); + sub_proto->set_argument_index( + TagIndex(ModelVisitor::kExpressionArgument)); + sub_proto->set_integer_expression_index(FindExpressionIndex(delegate)); + } else { + const int index = model_proto_->expressions_size(); + CPIntegerExpressionProto* const var_proto = + model_proto_->add_expressions(); + var_proto->set_index(index); + var_proto->set_type_index(TagIndex(ModelVisitor::kIntegerVariable)); + if (variable->HasName()) { + var_proto->set_name(variable->name()); + } + if (variable->Size() == variable->Max() - variable->Min() + 1) { + // Contiguous + CPArgumentProto* const min_proto = var_proto->add_arguments(); + min_proto->set_argument_index(TagIndex(ModelVisitor::kMinArgument)); + min_proto->set_integer_value(variable->Min()); + CPArgumentProto* const max_proto = var_proto->add_arguments(); + max_proto->set_argument_index(TagIndex(ModelVisitor::kMaxArgument)); + max_proto->set_integer_value(variable->Max()); + } else { + // Non Contiguous + CPArgumentProto* const values_proto = var_proto->add_arguments(); + values_proto->set_argument_index( + TagIndex(ModelVisitor::kValuesArgument)); + scoped_ptr it(variable->MakeDomainIterator(false)); + for (it->Init(); it->Ok(); it->Next()) { + values_proto->add_integer_array(it->Value()); + } + } + } + } + + virtual void VisitIntervalVariable(const IntervalVar* const variable, + const string operation, + const IntervalVar* const delegate) { + if (delegate != NULL) { + const int index = model_proto_->intervals_size(); + CPIntervalVariableProto* const var_proto = model_proto_->add_intervals(); + var_proto->set_index(index); + var_proto->set_type_index(TagIndex(ModelVisitor::kIntervalVariable)); + CPArgumentProto* const sub_proto = var_proto->add_arguments(); + sub_proto->set_argument_index(TagIndex(operation)); + sub_proto->set_interval_index(FindIntervalIndex(delegate)); + } else { + const int index = model_proto_->intervals_size(); + CPIntervalVariableProto* const var_proto = model_proto_->add_intervals(); + var_proto->set_index(index); + var_proto->set_type_index(TagIndex(ModelVisitor::kIntervalVariable)); + if (variable->HasName()) { + var_proto->set_name(variable->name()); + } + CPArgumentProto* const start_min_proto = var_proto->add_arguments(); + start_min_proto->set_argument_index( + TagIndex(ModelVisitor::kStartMinArgument)); + start_min_proto->set_integer_value(variable->StartMin()); + CPArgumentProto* const start_max_proto = var_proto->add_arguments(); + start_max_proto->set_argument_index( + TagIndex(ModelVisitor::kStartMaxArgument)); + start_max_proto->set_integer_value(variable->StartMax()); + CPArgumentProto* const end_min_proto = var_proto->add_arguments(); + end_min_proto->set_argument_index( + TagIndex(ModelVisitor::kEndMinArgument)); + end_min_proto->set_integer_value(variable->EndMin()); + CPArgumentProto* const end_max_proto = var_proto->add_arguments(); + end_max_proto->set_argument_index( + TagIndex(ModelVisitor::kEndMaxArgument)); + end_max_proto->set_integer_value(variable->EndMax()); + CPArgumentProto* const duration_min_proto = var_proto->add_arguments(); + duration_min_proto->set_argument_index( + TagIndex(ModelVisitor::kDurationMinArgument)); + duration_min_proto->set_integer_value(variable->DurationMin()); + CPArgumentProto* const duration_max_proto = var_proto->add_arguments(); + duration_max_proto->set_argument_index( + TagIndex(ModelVisitor::kDurationMaxArgument)); + duration_max_proto->set_integer_value(variable->DurationMax()); + CPArgumentProto* const optional_proto = var_proto->add_arguments(); + optional_proto->set_argument_index( + TagIndex(ModelVisitor::kOptionalArgument)); + optional_proto->set_integer_value(!variable->MustBePerformed()); + } + } + + virtual void VisitIntervalVariable(const IntervalVar* const variable, + const string operation, + const IntervalVar* const * delegates, + int size) { + CHECK_NOTNULL(delegates); + CHECK_GT(size, 0); + const int index = model_proto_->intervals_size(); + CPIntervalVariableProto* const var_proto = model_proto_->add_intervals(); + var_proto->set_index(index); + var_proto->set_type_index(TagIndex(ModelVisitor::kIntervalVariable)); + CPArgumentProto* const sub_proto = var_proto->add_arguments(); + sub_proto->set_argument_index(TagIndex(operation)); + for (int i = 0; i < size; ++i) { + sub_proto->add_interval_array(FindIntervalIndex(delegates[i])); + } + } + + int TagIndex(const string& tag) { + return tags_.Add(tag); + } + + private: + void WriteModelExtension(ArgumentHolder* const holder) { + CHECK_NOTNULL(holder); + if (holder->type_name().compare(kObjectiveExtension) == 0) { + WriteObjective(holder); + } else if (holder->type_name().compare(kSearchLimitExtension) == 0) { + WriteSearchLimit(holder); + } else if (holder->type_name().compare(kVariableGroupExtension) == 0) { + WriteVariableGroup(holder); + } else { + LOG(INFO) << "Unknown model extension :" << holder->type_name(); + } + } + + void WriteObjective(ArgumentHolder* const holder) { + CHECK_NOTNULL(holder); + const bool maximize = holder->FindIntegerArgumentOrDie(kMaximizeArgument); + const int64 step = holder->FindIntegerArgumentOrDie(kStepArgument); + const int objective_index = + holder->FindIntegerExpressionArgumentOrDie(kExpressionArgument); + CPObjectiveProto* const objective_proto = model_proto_->mutable_objective(); + objective_proto->set_maximize(maximize); + objective_proto->set_step(step); + objective_proto->set_objective_index(objective_index); + } + + void WriteSearchLimit(ArgumentHolder* const holder) { + CHECK_NOTNULL(holder); + SearchLimitProto* const proto = model_proto_->mutable_search_limit(); + proto->set_time(holder->FindIntegerArgumentWithDefault(kTimeLimitArgument, + kint64max)); + proto->set_branches(holder->FindIntegerArgumentWithDefault( + kBranchesLimitArgument, + kint64max)); + proto->set_failures(holder->FindIntegerArgumentWithDefault( + kFailuresLimitArgument, + kint64max)); + proto->set_solutions(holder->FindIntegerArgumentWithDefault( + kSolutionLimitArgument, + kint64max)); + proto->set_smart_time_check(holder->FindIntegerArgumentWithDefault( + kSmartTimeCheckArgument, + false)); + proto->set_cumulative(holder->FindIntegerArgumentWithDefault( + kCumulativeArgument, + false)); + } + + void WriteVariableGroup(ArgumentHolder* const holder) { + CPVariableGroup* const group_proto = model_proto_->add_variable_groups(); + holder->ExportToProto(&tags_, group_proto); + } + + template void ExportToProto(const A* const argument, + P* const proto, + const string& type_name, + int index) { + CHECK_NOTNULL(proto); + CHECK_NOTNULL(argument); + proto->set_index(index); + proto->set_type_index(TagIndex(type_name)); + if (argument->HasName()) { + proto->set_name(argument->name()); + } + top()->ExportToProto(&tags_, proto); + for (ConstIter > it(extensions_); + !it.at_end(); + ++it) { + CPExtensionProto* const extension_proto = proto->add_extensions(); + extension_proto->set_type_index(TagIndex((*it)->type_name())); + (*it)->ExportToProto(&tags_, extension_proto); + } + } + + void PushArgumentHolder() { + holders_.push_back(new ArgumentHolder); + } + + void PopArgumentHolder() { + CHECK(!holders_.empty()); + delete holders_.back(); + holders_.pop_back(); + STLDeleteElements(&extensions_); + extensions_.clear(); + } + + void PushExtension(const string& type_name) { + PushArgumentHolder(); + holders_.back()->set_type_name(type_name); + } + + void PopAndSaveExtension() { + CHECK(!holders_.empty()); + extensions_.push_back(holders_.back()); + holders_.pop_back(); + } + + ArgumentHolder* top() const { + CHECK(!holders_.empty()); + return holders_.back(); + } + + int FindExpressionIndex(const IntExpr* const expression) const { + const int result = FindWithDefault(expression_map_, expression, -1); + CHECK_NE(-1, result); + return result; + } + + int FindIntervalIndex(const IntervalVar* const interval) const { + const int result = FindWithDefault(interval_map_, interval, -1); + CHECK_NE(-1, result); + return result; + } + + hash_map expression_map_; + hash_map interval_map_; + hash_map delegate_map_; + std::vector expression_list_; + std::vector constraint_list_; + std::vector interval_list_; + CPModelProto* const model_proto_; + + std::vector holders_; + std::vector extensions_; + VectorMap tags_; +}; + +// ---------- Model Protocol Reader ---------- + +// ----- Utility Class for Callbacks ----- + +template class ArrayWithOffset : public BaseObject { + public: + ArrayWithOffset(int64 index_min, int64 index_max) + : index_min_(index_min), + index_max_(index_max), + values_(new T[index_max - index_min + 1]) { + DCHECK_LE(index_min, index_max); + } + + virtual ~ArrayWithOffset() {} + + virtual T Evaluate(int64 index) const { + DCHECK_GE(index, index_min_); + DCHECK_LE(index, index_max_); + return values_[index - index_min_]; + } + + void SetValue(int64 index, T value) { + DCHECK_GE(index, index_min_); + DCHECK_LE(index, index_max_); + values_[index - index_min_] = value; + } + + private: + const int64 index_min_; + const int64 index_max_; + scoped_array values_; +}; + +template void MakeCallbackFromProto( + CPModelBuilder* const builder, + const CPExtensionProto& proto, + int tag_index, + ResultCallback1** callback) { + DCHECK_EQ(tag_index, proto.type_index()); + Solver* const solver = builder->solver(); + int64 index_min = 0; + CHECK(builder->ScanArguments(ModelVisitor::kMinArgument, proto, &index_min)); + int64 index_max = 0; + CHECK(builder->ScanArguments(ModelVisitor::kMaxArgument, proto, &index_max)); + std::vector values; + CHECK(builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &values)); + ArrayWithOffset* const array = + solver->RevAlloc(new ArrayWithOffset(index_min, index_max)); + for (int i = index_min; i <= index_max; ++i) { + array->SetValue(i, values[i - index_min]); + } + *callback = NewPermanentCallback(array, &ArrayWithOffset::Evaluate); +} + +#define VERIFY(expr) if (!(expr)) return NULL +#define VERIFY_EQ(e1, e2) if ((e1) != (e2)) return NULL + +// ----- kAbs ----- + +IntExpr* BuildAbs(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + return builder->solver()->MakeAbs(expr); +} + +// ----- kAllDifferent ----- + +Constraint* BuildAllDifferent(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + int64 range = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kRangeArgument, proto, &range)); + return builder->solver()->MakeAllDifferent(vars, range); +} + +// ----- kAllowedAssignments ----- + +Constraint* BuildAllowedAssignments(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector > tuples; + VERIFY(builder->ScanArguments(ModelVisitor::kTuplesArgument, proto, &tuples)); + return builder->solver()->MakeAllowedAssignments(vars, tuples); +} + +// ----- kBetween ----- + +Constraint* BuildBetween(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + int64 value_min = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMinArgument, proto, &value_min)); + int64 value_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMaxArgument, proto, &value_max)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + return builder->solver()->MakeBetweenCt(expr->Var(), value_min, value_max); +} + +// ----- kConvexPiecewise ----- +IntExpr* BuildConvexPiecewise(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + int64 early_cost = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kEarlyCostArgument, + proto, + &early_cost)); + int64 early_date = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kEarlyDateArgument, + proto, + &early_date)); + int64 late_cost = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kLateCostArgument, + proto, + &late_cost)); + int64 late_date = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kLateDateArgument, + proto, + &late_date)); + return builder->solver()->MakeConvexPiecewiseExpr(expr->Var(), + early_cost, + early_date, + late_date, + late_cost); +} + +// ----- kCountEqual ----- + +Constraint* BuildCountEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + int64 count = 0; + if (builder->ScanArguments(ModelVisitor::kCountArgument, proto, &value)) { + return builder->solver()->MakeCount(vars, value, count); + } else { + IntExpr* count_expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kCountArgument, + proto, + &count_expr)); + return builder->solver()->MakeCount(vars, value, count_expr->Var()); + } +} + +// ----- kCumulative ----- + +Constraint* BuildCumulative(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalsArgument, + proto, + &vars)); + std::vector demands; + VERIFY(builder->ScanArguments(ModelVisitor::kDemandsArgument, + proto, + &demands)); + int64 capacity; + VERIFY(builder->ScanArguments(ModelVisitor::kCapacityArgument, + proto, + &capacity)); + string name; + if (proto.has_name()) { + name = proto.name(); + } + return builder->solver()->MakeCumulative(vars, demands, capacity, name); +} + +// ----- kDeviation ----- + +Constraint* BuildDeviation(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeDeviation(vars, target->Var(), value); +} + +// ----- kDifference ----- + +IntExpr* BuildDifference(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeDifference(left, right); + } + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeDifference(value, expr); +} + +// ----- kDistribute ----- + +Constraint* BuildDistribute(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + if (builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)) { + std::vector cards; + if (builder->ScanArguments(ModelVisitor::kCardsArgument, proto, &cards)) { + std::vector values; + if (builder->ScanArguments(ModelVisitor::kValuesArgument, + proto, + &values)) { + return builder->solver()->MakeDistribute(vars, values, cards); + } else { + return builder->solver()->MakeDistribute(vars, cards); + } + } else { + int64 card_min = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMinArgument, + proto, + &card_min)); + int64 card_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMaxArgument, + proto, + &card_max)); + int64 card_size = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kSizeArgument, + proto, + &card_size)); + return builder->solver()->MakeDistribute(vars, + card_min, + card_max, + card_size); + } + } else { + std::vector cards; + VERIFY(builder->ScanArguments(ModelVisitor::kCardsArgument, proto, &cards)); + return builder->solver()->MakeDistribute(vars, cards); + } +} + +// ----- kDivide ----- + +IntExpr* BuildDivide(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeDiv(expr, value); +} + +// ----- kDurationExpr ----- + +IntExpr* BuildDurationExpr(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntervalVar* var = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalArgument, proto, &var)); + return var->DurationExpr(); +} + +// ----- kElement ----- + +IntExpr* BuildElement(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* index = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, proto, &index)); + std::vector values; + if (proto.extensions_size() > 0) { + VERIFY_EQ(1, proto.extensions_size()); + Solver::IndexEvaluator1 * callback = NULL; + const int extension_tag_index = + builder->TagIndex(ModelVisitor::kInt64ToInt64Extension); + MakeCallbackFromProto(builder, + proto.extensions(0), + extension_tag_index, + &callback); + return builder->solver()->MakeElement(callback, index->Var()); + } + if (builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &values)) { + return builder->solver()->MakeElement(values, index->Var()); + } + std::vector vars; + if (builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)) { + return builder->solver()->MakeElement(vars, index->Var()); + } + return NULL; +} + +// ----- kElementEqual ----- +// TODO(user): Add API on solver and uncomment this method. +/* + Constraint* BuildElementEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, + proto, + &target)); + std::vector values; + if (builder->ScanArguments(ModelVisitor::kValuesArgument, + proto, + &values)) { + IntExpr* index = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, + proto, + &index)); + return builder->solver()->MakeElement(values, index->Var()); + } + std::vector vars; + if (builder->ScanArguments(ModelVisitor::kVarsArgument, + proto, + &vars)) { + IntExpr* index = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, + proto, + &index)); + return builder->solver()->MakeElement(vars, index->Var()); + } + return NULL; + } +*/ + +// ----- kEndExpr ----- + +IntExpr* BuildEndExpr(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntervalVar* var = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalArgument, proto, &var)); + return var->EndExpr(); +} + +// ----- kEquality ----- + +Constraint* BuildEquality(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeEquality(left->Var(), right->Var()); + } + IntExpr* expr = NULL; + if (builder->ScanArguments(ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeEquality(expr->Var(), value); + } + return NULL; +} + +// ----- kFalseConstraint ----- + +Constraint* BuildFalseConstraint(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + return builder->solver()->MakeFalseConstraint(); +} + +// ----- kGreater ----- + +Constraint* BuildGreater(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* left = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)); + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeGreater(left->Var(), right->Var()); +} + +// ----- kGreaterOrEqual ----- + +Constraint* BuildGreaterOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeGreaterOrEqual(left->Var(), right->Var()); + } + IntExpr* expr = NULL; + if (builder->ScanArguments(ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeGreaterOrEqual(expr->Var(), value); + } + return NULL; +} + +// ----- kIntegerVariable ----- + +IntExpr* BuildIntegerVariable(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* sub_expression = NULL; + if (builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &sub_expression)) { + IntVar* const result = sub_expression->Var(); + if (proto.has_name()) { + result->set_name(proto.name()); + } + return result; + } + int64 var_min = 0; + if (builder->ScanArguments(ModelVisitor::kMinArgument, proto, &var_min)) { + int64 var_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMaxArgument, proto, &var_max)); + IntVar* const result = builder->solver()->MakeIntVar(var_min, var_max); + if (proto.has_name()) { + result->set_name(proto.name()); + } + return result; + } + std::vector values; + if (builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &values)) { + IntVar* const result = builder->solver()->MakeIntVar(values); + if (proto.has_name()) { + result->set_name(proto.name()); + } + return result; + } + return NULL; +} + +// ----- kIntervalBinaryRelation ----- + +Constraint* BuildIntervalBinaryRelation(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntervalVar* left = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)); + IntervalVar* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + int64 relation = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kRelationArgument, + proto, + &relation)); + Solver::BinaryIntervalRelation rel = + static_cast(relation); + return builder->solver()->MakeIntervalVarRelation(left, rel, right); +} + +// ----- kIntervalDisjunction ----- + +Constraint* BuildIntervalDisjunction(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntervalVar* left = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)); + IntervalVar* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeTemporalDisjunction(left, right, target->Var()); +} + +// ----- kIntervalUnaryRelation ----- + +Constraint* BuildIntervalUnaryRelation(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntervalVar* interval = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalArgument, + proto, + &interval)); + int64 date = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &date)); + int64 relation = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kRelationArgument, + proto, + &relation)); + Solver::UnaryIntervalRelation rel = + static_cast(relation); + return builder->solver()->MakeIntervalVarRelation(interval, rel, date); +} + +// ----- kIntervalVariable ----- + +IntervalVar* BuildIntervalVariable(CPModelBuilder* const builder, + const CPIntervalVariableProto& proto) { + Solver* const solver = builder->solver(); + int64 start_min = 0; + if (builder->ScanArguments(ModelVisitor::kStartMinArgument, + proto, + &start_min)) { + int64 start_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kStartMaxArgument, + proto, + &start_max)); + int64 end_min = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kEndMinArgument, + proto, + &end_min)); + int64 end_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kEndMaxArgument, + proto, + &end_max)); + int64 duration_min = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kDurationMinArgument, + proto, + &duration_min)); + int64 duration_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kDurationMaxArgument, + proto, + &duration_max)); + int64 optional = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kOptionalArgument, + proto, + &optional)); + VERIFY_EQ(duration_max, duration_min); + VERIFY_EQ(end_max - duration_max, start_max); + VERIFY_EQ(end_min - duration_min, start_min); + const string name = proto.name(); + if (start_min == start_max) { + return solver->MakeFixedInterval(start_min, duration_min, name); + } else { + return solver->MakeFixedDurationIntervalVar(start_min, + start_max, + duration_min, + optional, + name); + } + } else { + VERIFY_EQ(1, proto.arguments_size()); + const CPArgumentProto& sub_proto = proto.arguments(0); + IntervalVar* const derived = + builder->IntervalVariable(sub_proto.interval_index()); + const int operation_index = sub_proto.argument_index(); + DCHECK_NE(-1, operation_index); + if (operation_index == builder->TagIndex(ModelVisitor::kMirrorOperation)) { + return solver->MakeMirrorInterval(derived); + } else if (operation_index == + builder->TagIndex(ModelVisitor::kRelaxedMaxOperation)) { + solver->MakeIntervalRelaxedMax(derived); + } else if (operation_index == + builder->TagIndex(ModelVisitor::kRelaxedMinOperation)) { + solver->MakeIntervalRelaxedMin(derived); + } + } + return NULL; +} + +// ----- kIsBetween ----- + +Constraint* BuildIsBetween(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + int64 value_min = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMinArgument, proto, &value_min)); + int64 value_max = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kMaxArgument, proto, &value_max)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeIsBetweenCt(expr->Var(), + value_min, + value_max, + target->Var()); +} + +// ----- kIsDifferent ----- + +Constraint* BuildIsDifferent(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeIsDifferentCstCt(expr->Var(), + value, + target->Var()); +} + +// ----- kIsEqual ----- + +Constraint* BuildIsEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeIsEqualCstCt(expr->Var(), + value, + target->Var()); +} + +// ----- kIsGreaterOrEqual ----- + +Constraint* BuildIsGreaterOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeIsGreaterOrEqualCstCt(expr->Var(), + value, + target->Var()); +} + +// ----- kIsLessOrEqual ----- + +Constraint* BuildIsLessOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeIsLessOrEqualCstCt(expr->Var(), + value, + target->Var()); +} + +// ----- kIsMember ----- + +Constraint* BuildIsMember(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &values)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeIsMemberCt(expr->Var(), values, target->Var()); +} + +// ----- kLess ----- + +Constraint* BuildLess(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* left = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)); + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeLess(left->Var(), right->Var()); +} + +// ----- kLessOrEqual ----- + +Constraint* BuildLessOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeLessOrEqual(left->Var(), right->Var()); + } + IntExpr* expr = NULL; + if (builder->ScanArguments(ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeLessOrEqual(expr->Var(), value); + } + return NULL; +} + +// ----- kMapDomain ----- + +Constraint* BuildMapDomain(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeMapDomain(target->Var(), vars); +} + +// ----- kMax ----- + +IntExpr* BuildMax(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeMax(left, right); + } + IntExpr* expr = NULL; + if (builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeMax(expr, value); + } + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + return builder->solver()->MakeMax(vars); +} + +// ----- kMaxEqual ----- + +// TODO(user): Add API on solver and uncomment this method. +/* + Constraint* BuildMaxEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, + proto, + &vars)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, + proto, + &target)); + return builder->solver()->MakeMaxEqual(vars, target->Var()); + } +*/ + +// ----- kMember ----- + +Constraint* BuildMember(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &values)); + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + return builder->solver()->MakeMemberCt(expr->Var(), values); +} + +// ----- kMin ----- + +IntExpr* BuildMin(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeMin(left, right); + } + IntExpr* expr = NULL; + if (builder->ScanArguments(ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeMin(expr, value); + } + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + return builder->solver()->MakeMin(vars); +} + +// ----- kMinEqual ----- + +// TODO(user): Add API on solver and implement this method. + +// ----- kNoCycle ----- + +Constraint* BuildNoCycle(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector nexts; + VERIFY(builder->ScanArguments(ModelVisitor::kNextsArgument, proto, &nexts)); + std::vector active; + VERIFY(builder->ScanArguments(ModelVisitor::kActiveArgument, proto, &active)); + int64 assume_paths = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kAssumePathsArgument, + proto, + &assume_paths)); + ResultCallback1* sink_handler = NULL; + if (proto.extensions_size() > 0) { + VERIFY_EQ(1, proto.extensions_size()); + const int tag_index = + builder->TagIndex(ModelVisitor::kInt64ToBoolExtension); + MakeCallbackFromProto(builder, + proto.extensions(0), + tag_index, + &sink_handler); + } + return builder->solver()->MakeNoCycle(nexts, active, NULL, assume_paths); +} + +// ----- kNonEqual ----- + +Constraint* BuildNonEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeNonEquality(left->Var(), right->Var()); + } + IntExpr* expr = NULL; + if (builder->ScanArguments(ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeNonEquality(expr->Var(), value); + } + return NULL; +} + +// ----- kOpposite ----- + +IntExpr* BuildOpposite(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + return builder->solver()->MakeOpposite(expr); +} + +// ----- kPack ----- + +bool AddUsageLessConstantDimension(Pack* const pack, + CPModelBuilder* const builder, + const CPExtensionProto& proto) { + std::vector weights; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &weights)); + std::vector upper; + VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &upper)); + pack->AddWeightedSumLessOrEqualConstantDimension(weights, upper); + return true; +} + +bool AddCountAssignedItemsDimension(Pack* const pack, + CPModelBuilder* const builder, + const CPExtensionProto& proto) { + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + pack->AddCountAssignedItemsDimension(target->Var()); + return true; +} + +bool AddCountUsedBinDimension(Pack* const pack, + CPModelBuilder* const builder, + const CPExtensionProto& proto) { + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + pack->AddCountUsedBinDimension(target->Var()); + return true; +} + +bool AddUsageEqualVariableDimension(Pack* const pack, + CPModelBuilder* const builder, + const CPExtensionProto& proto) { + std::vector weights; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &weights)); + std::vector loads; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &loads)); + pack->AddWeightedSumEqualVarDimension(weights, loads); + return true; +} + +bool AddVariableUsageLessConstantDimension(Pack* const pack, + CPModelBuilder* const builder, + const CPExtensionProto& proto) { + std::vector uppers; + VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument, proto, &uppers)); + std::vector usages; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &usages)); + pack->AddSumVariableWeightsLessOrEqualConstantDimension(usages, uppers); + return true; +} + +bool AddWeightedSumOfAssignedDimension(Pack* const pack, + CPModelBuilder* const builder, + const CPExtensionProto& proto) { + std::vector weights; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &weights)); + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + pack->AddWeightedSumOfAssignedDimension(weights, target->Var()); + return true; +} + +#define IS_TYPE(index, builder, tag) \ + index == builder->TagIndex(ModelVisitor::tag) + +Constraint* BuildPack(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + int64 bins = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kSizeArgument, proto, &bins)); + Pack* const pack = builder->solver()->MakePack(vars, bins); + // Add dimensions. They are stored as extensions in the proto. + for (int i = 0; i < proto.extensions_size(); ++i) { + const CPExtensionProto& dimension_proto = proto.extensions(i); + const int type_index = dimension_proto.type_index(); + if (IS_TYPE(type_index, builder, kUsageLessConstantExtension)) { + VERIFY(AddUsageLessConstantDimension(pack, builder, dimension_proto)); + } else if (IS_TYPE(type_index, builder, kCountAssignedItemsExtension)) { + VERIFY(AddCountAssignedItemsDimension(pack, builder, dimension_proto)); + } else if (IS_TYPE(type_index, builder, kCountUsedBinsExtension)) { + VERIFY(AddCountUsedBinDimension(pack, builder, dimension_proto)); + } else if (IS_TYPE(type_index, builder, kUsageEqualVariableExtension)) { + VERIFY(AddUsageEqualVariableDimension(pack, builder, dimension_proto)); + } else if (IS_TYPE(type_index, + builder, + kVariableUsageLessConstantExtension)) { + VERIFY(AddVariableUsageLessConstantDimension(pack, + builder, + dimension_proto)); + } else if (IS_TYPE(type_index, + builder, + kWeightedSumOfAssignedEqualVariableExtension)) { + VERIFY(AddWeightedSumOfAssignedDimension(pack, builder, dimension_proto)); + } else { + LOG(ERROR) << "Unrecognized extension " << dimension_proto.DebugString(); + return NULL; + } + } + return pack; +} +#undef IS_TYPE + +// ----- kPathCumul ----- + +Constraint* BuildPathCumul(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector nexts; + VERIFY(builder->ScanArguments(ModelVisitor::kNextsArgument, proto, &nexts)); + std::vector active; + VERIFY(builder->ScanArguments(ModelVisitor::kActiveArgument, proto, &active)); + std::vector cumuls; + VERIFY(builder->ScanArguments(ModelVisitor::kCumulsArgument, proto, &cumuls)); + std::vector transits; + VERIFY(builder->ScanArguments(ModelVisitor::kTransitsArgument, + proto, + &transits)); + return builder->solver()->MakePathCumul(nexts, active, cumuls, transits); +} + +// ----- kPerformedExpr ----- + +IntExpr* BuildPerformedExpr(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntervalVar* var = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalArgument, proto, &var)); + return var->PerformedExpr(); +} + +// ----- kProduct ----- + +IntExpr* BuildProduct(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeProd(left, right); + } + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeProd(expr, value); +} + +// ----- kScalProd ----- + +IntExpr* BuildScalProd(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &values)); + return builder->solver()->MakeScalProd(vars, values); +} + +// ----- kScalProdEqual ----- + +Constraint* BuildScalProdEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &values)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeScalProdEquality(vars, values, value); +} + +// ----- kScalProdGreaterOrEqual ----- + +Constraint* BuildScalProdGreaterOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &values)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeScalProdGreaterOrEqual(vars, values, value); +} + +// ----- kScalProdLessOrEqual ----- + +Constraint* BuildScalProdLessOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kCoefficientsArgument, + proto, + &values)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeScalProdLessOrEqual(vars, values, value); +} + +// ----- kSemiContinuous ----- + +IntExpr* BuildSemiContinuous(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + int64 fixed_charge = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kFixedChargeArgument, + proto, + &fixed_charge)); + int64 step = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kStepArgument, proto, &step)); + return builder->solver()->MakeSemiContinuousExpr(expr, fixed_charge, step); +} + +// ----- kSequence ----- + +Constraint* BuildSequence(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalsArgument, + proto, + &vars)); + return builder->solver()->MakeSequence(vars, proto.name()); +} + +// ----- kSquare ----- + +IntExpr* BuildSquare(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kExpressionArgument, + proto, + &expr)); + return builder->solver()->MakeSquare(expr); +} + +// ----- kStartExpr ----- + +IntExpr* BuildStartExpr(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntervalVar* var = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIntervalArgument, proto, &var)); + return var->StartExpr(); +} + +// ----- kSum ----- + +IntExpr* BuildSum(CPModelBuilder* const builder, + const CPIntegerExpressionProto& proto) { + IntExpr* left = NULL; + if (builder->ScanArguments(ModelVisitor::kLeftArgument, proto, &left)) { + IntExpr* right = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kRightArgument, proto, &right)); + return builder->solver()->MakeSum(left, right); + } + IntExpr* expr = NULL; + if (builder->ScanArguments( + ModelVisitor::kExpressionArgument, proto, &expr)) { + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeSum(expr, value); + } + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + return builder->solver()->MakeSum(vars); +} + +// ----- kSumEqual ----- + +Constraint* BuildSumEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + int64 value = 0; + if (builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)) { + return builder->solver()->MakeSumEquality(vars, value); + } + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, proto, &target)); + return builder->solver()->MakeSumEquality(vars, target->Var()); +} + +// ----- kSumGreaterOrEqual ----- + +Constraint* BuildSumGreaterOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeSumGreaterOrEqual(vars, value); +} + +// ----- kSumLessOrEqual ----- + +Constraint* BuildSumLessOrEqual(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + int64 value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kValueArgument, proto, &value)); + return builder->solver()->MakeSumLessOrEqual(vars, value); +} + +// ----- kTransition ----- + +Constraint* BuildTransition(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector > tuples; + VERIFY(builder->ScanArguments(ModelVisitor::kTuplesArgument, proto, &tuples)); + int64 initial_state = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kInitialState, + proto, + &initial_state)); + std::vector final_states; + VERIFY(builder->ScanArguments(ModelVisitor::kFinalStatesArgument, + proto, + &final_states)); + + + + return builder->solver()->MakeTransitionConstraint(vars, + tuples, + initial_state, + final_states); +} + +// ----- kTrueConstraint ----- + +Constraint* BuildTrueConstraint(CPModelBuilder* const builder, + const CPConstraintProto& proto) { + return builder->solver()->MakeTrueConstraint(); +} + +#undef VERIFY +#undef VERIFY_EQ +} // namespace + +// ----- CPModelBuilder ----- + +bool CPModelBuilder::BuildFromProto(const CPIntegerExpressionProto& proto) { + const int index = proto.index(); + const int tag_index = proto.type_index(); + Solver::IntegerExpressionBuilder* const builder = + solver_->GetIntegerExpressionBuilder(tags_.Element(tag_index)); + if (!builder) { + return false; + } + IntExpr* const built = builder->Run(this, proto); + if (!built) { + return false; + } + expressions_.resize(std::max(static_cast(expressions_.size()), + index + 1)); + expressions_[index] = built; + return true; +} + +Constraint* CPModelBuilder::BuildFromProto(const CPConstraintProto& proto) { + const int tag_index = proto.type_index(); + Solver::ConstraintBuilder* const builder = + solver_->GetConstraintBuilder(tags_.Element(tag_index)); + if (!builder) { + return NULL; + } + Constraint* const built = builder->Run(this, proto); + return built; +} + +bool CPModelBuilder::BuildFromProto(const CPIntervalVariableProto& proto) { + const int index = proto.index(); + const int tag_index = proto.type_index(); + Solver::IntervalVariableBuilder* const builder = + solver_->GetIntervalVariableBuilder(tags_.Element(tag_index)); + if (!builder) { + return NULL; + } + IntervalVar* const built = builder->Run(this, proto); + if (!built) { + return false; + } + intervals_.resize(std::max(static_cast(intervals_.size()), index + 1)); + intervals_[index] = built; + return true; +} + +IntExpr* CPModelBuilder::IntegerExpression(int index) const { + CHECK_GE(index, 0); + CHECK_LT(index, expressions_.size()); + CHECK_NOTNULL(expressions_[index]); + return expressions_[index]; +} + +IntervalVar* CPModelBuilder::IntervalVariable(int index) const { + CHECK_GE(index, 0); + CHECK_LT(index, intervals_.size()); + CHECK_NOTNULL(intervals_[index]); + return intervals_[index]; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + int64* to_fill) { + if (arg_proto.argument_index() == type_index && + arg_proto.has_integer_value()) { + *to_fill = arg_proto.integer_value(); + return true; + } + return false; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + IntExpr** to_fill) { + if (arg_proto.argument_index() == type_index && + arg_proto.has_integer_expression_index()) { + const int expression_index = arg_proto.integer_expression_index(); + CHECK_NOTNULL(expressions_[expression_index]); + *to_fill = expressions_[expression_index]; + return true; + } + return false; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector* to_fill) { + if (arg_proto.argument_index() == type_index) { + const int values_size = arg_proto.integer_array_size(); + for (int j = 0; j < values_size; ++j) { + to_fill->push_back(arg_proto.integer_array(j)); + } + return true; + } + return false; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector >* to_fill) { + if (arg_proto.argument_index() == type_index && + arg_proto.has_integer_matrix()) { + to_fill->clear(); + const CPIntegerMatrixProto& matrix = arg_proto.integer_matrix(); + const int rows = matrix.rows(); + const int columns = matrix.columns(); + to_fill->resize(rows); + int counter = 0; + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < columns; ++j) { + const int64 value = matrix.values(counter++); + (*to_fill)[i].push_back(value); + } + } + CHECK_EQ(matrix.values_size(), counter); + return true; + } + return false; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector* to_fill) { + if (arg_proto.argument_index() == type_index) { + const int vars_size = arg_proto.integer_expression_array_size(); + for (int j = 0; j < vars_size; ++j) { + const int expression_index = arg_proto.integer_expression_array(j); + CHECK_NOTNULL(expressions_[expression_index]); + to_fill->push_back(expressions_[expression_index]->Var()); + } + return true; + } + return false; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + IntervalVar** to_fill) { + if (arg_proto.argument_index() == type_index && + arg_proto.has_interval_index()) { + const int interval_index = arg_proto.interval_index(); + CHECK_NOTNULL(intervals_[interval_index]); + *to_fill = intervals_[interval_index]; + return true; + } + return false; +} + +bool CPModelBuilder::ScanOneArgument(int type_index, + const CPArgumentProto& arg_proto, + std::vector* to_fill) { + if (arg_proto.argument_index() == type_index) { + const int vars_size = arg_proto.interval_array_size(); + for (int j = 0; j < vars_size; ++j) { + const int interval_index = arg_proto.interval_array(j); + CHECK_NOTNULL(intervals_[interval_index]); + to_fill->push_back(intervals_[interval_index]); + } + return true; + } + return false; +} + +// ----- Solver API ----- + +void Solver::ExportModel(const std::vector& monitors, + CPModelProto* const model_proto) const { + CHECK_NOTNULL(model_proto); + FirstPassVisitor first_pass; + Accept(&first_pass); + for (ConstIter > it(monitors); !it.at_end(); ++it) { + (*it)->Accept(&first_pass); + } + SecondPassVisitor second_pass(first_pass, model_proto); + for (ConstIter > it(monitors); !it.at_end(); ++it) { + (*it)->Accept(&second_pass); + } + Accept(&second_pass); +} + +void Solver::ExportModel(CPModelProto* const model_proto) const { + CHECK_NOTNULL(model_proto); + FirstPassVisitor first_pass; + Accept(&first_pass); + SecondPassVisitor second_pass(first_pass, model_proto); + Accept(&second_pass); +} + +bool Solver::LoadModel(const CPModelProto& model_proto) { + return LoadModel(model_proto, NULL); +} + +bool Solver::LoadModel(const CPModelProto& model_proto, + std::vector* monitors) { + if (model_proto.version() > kModelVersion) { + LOG(ERROR) << "Model protocol buffer version is greater than" + << " the one compiled in the reader (" + << model_proto.version() << " vs " << kModelVersion << ")"; + return false; + } + CPModelBuilder builder(this); + for (int i = 0; i < model_proto.tags_size(); ++i) { + builder.AddTag(model_proto.tags(i)); + } + for (int i = 0; i < model_proto.intervals_size(); ++i) { + if (!builder.BuildFromProto(model_proto.intervals(i))) { + LOG(ERROR) << "Interval variable proto " + << model_proto.intervals(i).DebugString() + << " was not parsed correctly"; + return false; + } + } + for (int i = 0; i < model_proto.expressions_size(); ++i) { + if (!builder.BuildFromProto(model_proto.expressions(i))) { + LOG(ERROR) << "Integer expression proto " + << model_proto.expressions(i).DebugString() + << " was not parsed correctly"; + return false; + } + } + for (int i = 0; i < model_proto.constraints_size(); ++i) { + Constraint* const constraint = + builder.BuildFromProto(model_proto.constraints(i)); + if (constraint == NULL) { + LOG(ERROR) << "Constraint proto " + << model_proto.constraints(i).DebugString() + << " was not parsed correctly"; + return false; + } + AddConstraint(constraint); + } + if (monitors != NULL) { + if (model_proto.has_search_limit()) { + monitors->push_back(MakeLimit(model_proto.search_limit())); + } + if (model_proto.has_objective()) { + const CPObjectiveProto& objective_proto = model_proto.objective(); + IntVar* const objective_var = + builder.IntegerExpression(objective_proto.objective_index())->Var(); + const bool maximize = objective_proto.maximize(); + const int64 step = objective_proto.step(); + OptimizeVar* const objective = + MakeOptimize(maximize, objective_var, step); + monitors->push_back(objective); + } + } + return true; +} + +bool Solver::UpgradeModel(CPModelProto* const proto) { + if (proto->version() == kModelVersion) { + LOG(INFO) << "Model already up to date with version " << kModelVersion; + } + return true; +} + +void Solver::RegisterBuilder(const string& tag, + ConstraintBuilder* const builder) { + InsertOrDie(&constraint_builders_, tag, builder); +} + +void Solver::RegisterBuilder(const string& tag, + IntegerExpressionBuilder* const builder) { + InsertOrDie(&expression_builders_, tag, builder); +} + +void Solver::RegisterBuilder(const string& tag, + IntervalVariableBuilder* const builder) { + InsertOrDie(&interval_builders_, tag, builder); +} + +Solver::ConstraintBuilder* +Solver::GetConstraintBuilder(const string& tag) const { + return FindPtrOrNull(constraint_builders_, tag); +} + +Solver::IntegerExpressionBuilder* +Solver::GetIntegerExpressionBuilder(const string& tag) const { + return FindPtrOrNull(expression_builders_, tag); +} + +Solver::IntervalVariableBuilder* +Solver::GetIntervalVariableBuilder(const string& tag) const { + IntervalVariableBuilder* const builder = + FindPtrOrNull(interval_builders_, tag); + return builder; +} + +// ----- Manage builders ----- + +#define REGISTER(tag, func) \ + RegisterBuilder(ModelVisitor::tag, NewPermanentCallback(&func)) + +void Solver::InitBuilders() { + REGISTER(kAbs, BuildAbs); + REGISTER(kAllDifferent, BuildAllDifferent); + REGISTER(kAllowedAssignments, BuildAllowedAssignments); + REGISTER(kBetween, BuildBetween); + REGISTER(kConvexPiecewise, BuildConvexPiecewise); + REGISTER(kCountEqual, BuildCountEqual); + REGISTER(kCumulative, BuildCumulative); + REGISTER(kDeviation, BuildDeviation); + REGISTER(kDifference, BuildDifference); + REGISTER(kDistribute, BuildDistribute); + REGISTER(kDivide, BuildDivide); + REGISTER(kDurationExpr, BuildDurationExpr); + REGISTER(kElement, BuildElement); + // REGISTER(kElementEqual, BuildElementEqual); + REGISTER(kEndExpr, BuildEndExpr); + REGISTER(kEquality, BuildEquality); + REGISTER(kFalseConstraint, BuildFalseConstraint); + REGISTER(kGreater, BuildGreater); + REGISTER(kGreaterOrEqual, BuildGreaterOrEqual); + REGISTER(kIntegerVariable, BuildIntegerVariable); + REGISTER(kIntervalBinaryRelation, BuildIntervalBinaryRelation); + REGISTER(kIntervalDisjunction, BuildIntervalDisjunction); + REGISTER(kIntervalUnaryRelation, BuildIntervalUnaryRelation); + REGISTER(kIntervalVariable, BuildIntervalVariable); + REGISTER(kIsBetween, BuildIsBetween); + REGISTER(kIsDifferent, BuildIsDifferent); + REGISTER(kIsEqual, BuildIsEqual); + REGISTER(kIsGreaterOrEqual, BuildIsGreaterOrEqual); + REGISTER(kIsLessOrEqual, BuildIsLessOrEqual); + REGISTER(kIsMember, BuildIsMember); + REGISTER(kLess, BuildLess); + REGISTER(kLessOrEqual, BuildLessOrEqual); + REGISTER(kMapDomain, BuildMapDomain); + REGISTER(kMax, BuildMax); + // REGISTER(kMaxEqual, BuildMaxEqual); + REGISTER(kMember, BuildMember); + REGISTER(kMin, BuildMin); + // REGISTER(kMinEqual, BuildMinEqual); + REGISTER(kNoCycle, BuildNoCycle); + REGISTER(kNonEqual, BuildNonEqual); + REGISTER(kOpposite, BuildOpposite); + REGISTER(kPack, BuildPack); + REGISTER(kPathCumul, BuildPathCumul); + REGISTER(kPerformedExpr, BuildPerformedExpr); + REGISTER(kProduct, BuildProduct); + REGISTER(kScalProd, BuildScalProd); + REGISTER(kScalProdEqual, BuildScalProdEqual); + REGISTER(kScalProdGreaterOrEqual, BuildScalProdGreaterOrEqual); + REGISTER(kScalProdLessOrEqual, BuildScalProdLessOrEqual); + REGISTER(kSemiContinuous, BuildSemiContinuous); + REGISTER(kSequence, BuildSequence); + REGISTER(kSquare, BuildSquare); + REGISTER(kStartExpr, BuildStartExpr); + REGISTER(kSum, BuildSum); + REGISTER(kSumEqual, BuildSumEqual); + REGISTER(kSumGreaterOrEqual, BuildSumGreaterOrEqual); + REGISTER(kSumLessOrEqual, BuildSumLessOrEqual); + REGISTER(kTransition, BuildTransition); + REGISTER(kTrueConstraint, BuildTrueConstraint); +} +#undef REGISTER + +void Solver::DeleteBuilders() { + STLDeleteValues(&expression_builders_); + STLDeleteValues(&constraint_builders_); + STLDeleteValues(&interval_builders_); +} +} // namespace operations_research diff --git a/constraint_solver/model.proto b/constraint_solver/model.proto new file mode 100644 index 0000000000..1f029cfa99 --- /dev/null +++ b/constraint_solver/model.proto @@ -0,0 +1,89 @@ +// Copyright 2010-2011 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; +import "constraint_solver/search_limit.proto"; + +package operations_research; + +message CPIntegerMatrixProto { + required int32 rows = 1; + required int32 columns = 2; + repeated int64 values = 3; +} + +// This message holds one argument of a constraint or expression. It +// is referenced by the argument_name. Only one field apart the name +// must be set. +message CPArgumentProto { + required int32 argument_index = 1; + optional int64 integer_value = 2; + repeated int64 integer_array = 3; + optional int32 integer_expression_index = 4; + repeated int32 integer_expression_array = 5; + optional int32 interval_index = 6; + repeated int32 interval_array = 7; + optional CPIntegerMatrixProto integer_matrix = 8; +} + +message CPExtensionProto { + required int32 type_index = 1; + repeated CPArgumentProto arguments = 2; +} + +message CPIntegerExpressionProto { + required int32 index = 1; + required int32 type_index = 2; + optional string name = 3; + repeated CPArgumentProto arguments = 4; + repeated CPExtensionProto extensions = 5; +} + +message CPIntervalVariableProto { + required int32 index = 1; + required int32 type_index = 2; + optional string name = 3; + repeated CPArgumentProto arguments = 4; +} + +message CPConstraintProto { + required int32 index = 1; + required int32 type_index = 2; + optional string name = 3; + repeated CPArgumentProto arguments = 4; + repeated CPExtensionProto extensions = 5; +} + +message CPObjectiveProto { + required bool maximize = 1; + required int64 step = 2; + required int32 objective_index = 3; +} + +message CPVariableGroup { + repeated CPArgumentProto arguments = 1; + optional string type = 2; +} + +message CPModelProto { + required string model = 1; + required int32 version = 2; + repeated string tags = 3; + repeated CPIntegerExpressionProto expressions = 4; + repeated CPIntervalVariableProto intervals = 5; + repeated CPConstraintProto constraints = 6; + optional CPObjectiveProto objective = 7; + optional SearchLimitProto search_limit = 8; + repeated CPVariableGroup variable_groups = 9; + optional string licence_text = 10; +} diff --git a/constraint_solver/pack.cc b/constraint_solver/pack.cc index bd4755f274..a3caf64a3f 100644 --- a/constraint_solver/pack.cc +++ b/constraint_solver/pack.cc @@ -1235,6 +1235,8 @@ void Pack::AddWeightedSumOfAssignedDimension(const std::vector& weights, void Pack::AddSumVariableWeightsLessOrEqualConstantDimension( const std::vector& usage, const std::vector& capacity) { + CHECK_EQ(usage.size(), vsize_); + CHECK_EQ(capacity.size(), bins_); Solver* const s = solver(); Dimension* const dim = s->RevAlloc(new VariableUsageDimension(s, diff --git a/constraint_solver/search.cc b/constraint_solver/search.cc index 368f266099..5d5f64d7d9 100644 --- a/constraint_solver/search.cc +++ b/constraint_solver/search.cc @@ -32,6 +32,7 @@ #include "base/random.h" #include "constraint_solver/constraint_solveri.h" #include "constraint_solver/search_limit.pb.h" +#include "util/string_array.h" DEFINE_bool(cp_use_sparse_gls_penalties, false, "Use sparse implementation to store Guided Local Search penalties"); diff --git a/examples/model_util.cc b/examples/model_util.cc new file mode 100644 index 0000000000..1a2ca14335 --- /dev/null +++ b/examples/model_util.cc @@ -0,0 +1,314 @@ +// Copyright 2010-2011 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/commandlineflags.h" +#include "base/commandlineflags.h" +#include "base/integral_types.h" +#include "base/logging.h" +#include "base/macros.h" +#include "base/file.h" +#include "base/recordio.h" +#include "constraint_solver/constraint_solver.h" +#include "constraint_solver/model.pb.h" + +DEFINE_string(input, "", "Input file of the problem."); +DEFINE_string(output, "", "Output file when doing modifications."); +DEFINE_string(dot_file, "", "Exports model to dot file."); + +DEFINE_bool(print_proto, false, "Prints the raw model protobuf."); +DEFINE_bool(test_proto, false, "Performs various tests on the model protobuf."); +DEFINE_bool(model_stats, false, "Prints model statistics."); +DEFINE_bool(print_model, false, "Pretty print loaded model."); + +DEFINE_string(rename_model, "", "Renames to the model."); +DEFINE_bool(strip_limit, false, "Strips limits from the model."); +DEFINE_bool(strip_groups, false, "Strips variable groups from the model."); +DEFINE_bool(upgrade_proto, false, "Upgrade the model to the latest version."); +DEFINE_string(insert_licence, "", + "Insert content of the given file into the licence file."); + +namespace operations_research { +static const int kProblem = -1; +static const int kOk = 0; + +// ----- Export to .dot file ----- + +// Appends a string to te file. +void Write(File* const file, const string& string) { + file->Write(string.c_str(), string.size()); +} + +// Adds one link in the generated graph. +void WriteExprLink(const string& origin, + int index, + const string& label, + File* const file) { + const string other = StringPrintf("expr_%i", index); + Write(file, StringPrintf("%s -- %s [label=%s]\n", + origin.c_str(), + other.c_str(), + label.c_str())); +} + +// Adds one link in the generated graph. +void WriteIntervalLink(const string& origin, + int index, + const string& label, + File* const file) { + const string other = StringPrintf("interval_%i", index); + Write(file, StringPrintf("%s -- %s [label=%s]\n", + origin.c_str(), + other.c_str(), + label.c_str())); +} + +// Scans argument to add links in the graph. +template void ExportLinks(const CPModelProto& model, + const string& origin, + const T& proto, + File* const file) { + const string& arg_name = model.tags(proto.argument_index()); + if (proto.has_integer_expression_index()) { + WriteExprLink(origin, proto.integer_expression_index(), arg_name, file); + } + for (int i = 0; i < proto.integer_expression_array_size(); ++i) { + WriteExprLink(origin, proto.integer_expression_array(i), arg_name, file); + } + if (proto.has_interval_index()) { + WriteIntervalLink(origin, proto.interval_index(), arg_name, file); + } + for (int i = 0; i < proto.interval_array_size(); ++i) { + WriteIntervalLink(origin, proto.interval_array(i), arg_name, file); + } +} + +// Declares a labelled expression in the .dot file. +void DeclareExpression(int index, const CPModelProto& proto, File* const file) { + const CPIntegerExpressionProto& expr = proto.expressions(index); + const string short_name = StringPrintf("expr_%i", index); + if (expr.has_name()) { + Write(file, StringPrintf("%s [shape=oval label=\"%s\" color=green]\n", + short_name.c_str(), + expr.name().c_str())); + } else { + const string& type = proto.tags(expr.type_index()); + Write(file, StringPrintf("%s [shape=oval label=\"%s\"]\n", + short_name.c_str(), + type.c_str())); + } +} + +void DeclareInterval(int index, const CPModelProto& proto, File* const file) { + const CPIntervalVariableProto& interval = proto.intervals(index); + const string short_name = StringPrintf("interval_%i", index); + if (interval.has_name()) { + Write(file, StringPrintf("%s [shape=circle label=\"%s\" color=green]\n", + short_name.c_str(), + interval.name().c_str())); + } else { + const string& type = proto.tags(interval.type_index()); + Write(file, StringPrintf("%s [shape=oval label=\"%s\"]\n", + short_name.c_str(), + type.c_str())); + } +} + +void DeclareConstraint(int index, const CPModelProto& proto, File* const file) { + const CPConstraintProto& ct = proto.constraints(index); + const string& type = proto.tags(ct.type_index()); + const string short_name = StringPrintf("ct_%i", index); + Write(file, StringPrintf("%s [shape=box label=\"%s\"]\n", + short_name.c_str(), + type.c_str())); +} + +// Parses the proto and exports it to a .dot file. +void ExportToDot(const CPModelProto& proto, File* const file) { + Write(file, StringPrintf("graph %s {\n", proto.model().c_str())); + + for (int i = 0; i < proto.expressions_size(); ++i) { + DeclareExpression(i, proto, file); + } + + for (int i = 0; i < proto.intervals_size(); ++i) { + DeclareInterval(i, proto, file); + } + + for (int i = 0; i < proto.constraints_size(); ++i) { + DeclareConstraint(i, proto, file); + } + + if (proto.has_objective()) { + if (proto.objective().maximize()) { + Write(file, "obj [shape=diamond label=\"Maximize\" color=red]\n"); + } else { + Write(file, "obj [shape=diamond label=\"Minimize\" color=red]\n"); + } + } + + for (int i = 0; i < proto.expressions_size(); ++i) { + const CPIntegerExpressionProto& expr = proto.expressions(i); + const string short_name = StringPrintf("expr_%i", i); + for (int j = 0; j < expr.arguments_size(); ++j) { + ExportLinks(proto, short_name, expr.arguments(j), file); + } + } + + for (int i = 0; i < proto.intervals_size(); ++i) { + const CPIntervalVariableProto& interval = proto.intervals(i); + const string short_name = StringPrintf("interval_%i", i); + for (int j = 0; j < interval.arguments_size(); ++j) { + ExportLinks(proto, short_name, interval.arguments(j), file); + } + } + + for (int i = 0; i < proto.constraints_size(); ++i) { + const CPConstraintProto& ct = proto.constraints(i); + const string short_name = StringPrintf("ct_%i", i); + for (int j = 0; j < ct.arguments_size(); ++j) { + ExportLinks(proto, short_name, ct.arguments(j), file); + } + } + + if (proto.has_objective()) { + const CPObjectiveProto& obj = proto.objective(); + WriteExprLink("obj", + obj.objective_index(), + ModelVisitor::kExpressionArgument, + file); + } + + Write(file, "}\n"); +} + +// ----- Main Method ----- + +int Run() { + // ----- Load input file into protobuf ----- + + File::Init(); + File* const file = File::Open(FLAGS_input, "r"); + if (file == NULL) { + LOG(WARNING) << "Cannot open " << FLAGS_input; + return kProblem; + } + + CPModelProto model_proto; + RecordReader reader(file); + if (!(reader.ReadProtocolMessage(&model_proto) && reader.Close())) { + LOG(INFO) << "No model found in " << file->CreateFileName(); + return kProblem; + } + + // ----- Display loaded protobuf ----- + + LOG(INFO) << "Read model " << model_proto.model(); + if (model_proto.has_licence_text()) { + LOG(INFO) << "Licence = " << model_proto.licence_text(); + } + + // ----- Modifications ----- + + if (!FLAGS_rename_model.empty()) { + model_proto.set_model(FLAGS_rename_model); + } + + if (FLAGS_strip_limit) { + model_proto.clear_search_limit(); + } + + if (FLAGS_strip_groups) { + model_proto.clear_variable_groups(); + } + + if (FLAGS_upgrade_proto) { + if (!Solver::UpgradeModel(&model_proto)) { + LOG(ERROR) << "Model upgrade failed"; + return kProblem; + } + } + + if (!FLAGS_insert_licence.empty()) { + File* const licence = File::Open(FLAGS_insert_licence, "r"); + if (licence == NULL) { + LOG(WARNING) << "Cannot open " << FLAGS_insert_licence; + return kProblem; + } + const int size = licence->Size(); + char* const text = new char[size + 1]; + licence->Read(text, size); + model_proto.set_licence_text(text); + licence->Close(); + } + + // ----- Reporting ----- + + if (FLAGS_print_proto) { + LOG(INFO) << model_proto.DebugString(); + } + if (FLAGS_test_proto || FLAGS_model_stats || FLAGS_print_model) { + Solver solver(model_proto.model()); + std::vector monitors; + if (!solver.LoadModel(model_proto, &monitors)) { + LOG(INFO) << "Could not load model into the solver"; + return kProblem; + } + if (FLAGS_test_proto) { + LOG(INFO) << "Model " << model_proto.model() << " loaded OK"; + } + if (FLAGS_model_stats) { + ModelVisitor* const visitor = solver.MakeStatisticsModelVisitor(); + solver.Accept(visitor, monitors); + } + if (FLAGS_print_model) { + ModelVisitor* const visitor = solver.MakePrintModelVisitor(); + solver.Accept(visitor, monitors); + } + } + + // ----- Output ----- + + if (!FLAGS_output.empty()) { + File* const output = File::Open(FLAGS_output, "w"); + if (output == NULL) { + LOG(INFO) << "Cannot open " << FLAGS_output; + return kProblem; + } + RecordWriter writer(output); + if (!(writer.WriteProtocolMessage(model_proto) && writer.Close())) { + return kProblem; + } else { + LOG(INFO) << "Model successfully written to " << FLAGS_output; + } + } + + if (!FLAGS_dot_file.empty()) { + File* const dot_file = File::Open(FLAGS_dot_file, "w"); + if (dot_file == NULL) { + LOG(INFO) << "Cannot open " << FLAGS_dot_file; + return kProblem; + } + ExportToDot(model_proto, dot_file); + dot_file->Close(); + } + return kOk; +} +} // namespace operations_research + +int main(int argc, char **argv) { + google::ParseCommandLineFlags(&argc, &argv, true); + if (FLAGS_input.empty()) { + LOG(FATAL) << "Filename not specified"; + } + return operations_research::Run(); +}