diff --git a/cmake/python.cmake b/cmake/python.cmake index 45c43f3ca7..ac3a7ef237 100644 --- a/cmake/python.cmake +++ b/cmake/python.cmake @@ -419,7 +419,6 @@ if(BUILD_MATH_OPT) endif() file(COPY ortools/sat/python/cp_model.py - ortools/sat/python/cp_model_numbers.py DESTINATION ${PYTHON_PROJECT_DIR}/sat/python) file(COPY ortools/sat/colab/flags.py @@ -699,6 +698,10 @@ add_custom_command( $ ${PYTHON_PROJECT}/routing/python COMMAND ${CMAKE_COMMAND} -E copy $ ${PYTHON_PROJECT}/sat/python + COMMAND ${CMAKE_COMMAND} -E copy + $ ${PYTHON_PROJECT}/sat/python + COMMAND ${CMAKE_COMMAND} -E copy + $ ${PYTHON_PROJECT}/sat/python COMMAND ${CMAKE_COMMAND} -E copy $ ${PYTHON_PROJECT}/scheduling/python COMMAND ${CMAKE_COMMAND} -E copy @@ -724,7 +727,9 @@ add_custom_command( $<$:math_opt_elemental_pybind11> $<$:math_opt_io_pybind11> $ + cp_model_builder_pybind cp_model_helper_pybind11 + sat_parameters_builder_pybind rcpsp_pybind11 set_cover_pybind11 sorted_interval_list_pybind11 diff --git a/examples/contrib/permutation_flow_shop.py b/examples/contrib/permutation_flow_shop.py index 505ce7dbfd..738d314399 100644 --- a/examples/contrib/permutation_flow_shop.py +++ b/examples/contrib/permutation_flow_shop.py @@ -27,7 +27,6 @@ import numpy as np from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _PARAMS = flags.DEFINE_string( @@ -149,7 +148,7 @@ def permutation_flow_shop( solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) solver.parameters.log_search_progress = log solver.parameters.max_time_in_seconds = time_limit diff --git a/examples/contrib/scheduling_with_transitions_sat.py b/examples/contrib/scheduling_with_transitions_sat.py index cbaaf78b94..abf4461a65 100644 --- a/examples/contrib/scheduling_with_transitions_sat.py +++ b/examples/contrib/scheduling_with_transitions_sat.py @@ -9,7 +9,6 @@ import argparse import collections from ortools.sat.python import cp_model -from google.protobuf import text_format #---------------------------------------------------------------------------- # Command line arguments. @@ -295,7 +294,7 @@ def main(args): solver = cp_model.CpSolver() solver.parameters.max_time_in_seconds = 60 * 60 * 2 if parameters: - text_format.Merge(parameters, solver.parameters) + solver.parameters.merge_text_format(parameters) solution_printer = SolutionPrinter(makespan) status = solver.Solve(model, solution_printer) diff --git a/examples/python/arc_flow_cutting_stock_sat.py b/examples/python/arc_flow_cutting_stock_sat.py index 9b70f2c2ca..2cad61f567 100644 --- a/examples/python/arc_flow_cutting_stock_sat.py +++ b/examples/python/arc_flow_cutting_stock_sat.py @@ -21,7 +21,6 @@ from absl import app from absl import flags import numpy as np -from google.protobuf import text_format from ortools.linear_solver.python import model_builder as mb from ortools.sat.python import cp_model @@ -319,7 +318,7 @@ def solve_cutting_stock_with_arc_flow_and_sat(output_proto_file: str, params: st # Solve model. solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) solver.parameters.log_search_progress = True solver.Solve(model) diff --git a/examples/python/bus_driver_scheduling_sat.py b/examples/python/bus_driver_scheduling_sat.py index 64f77118d7..73a1e55217 100644 --- a/examples/python/bus_driver_scheduling_sat.py +++ b/examples/python/bus_driver_scheduling_sat.py @@ -30,7 +30,6 @@ import math from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _OUTPUT_PROTO = flags.DEFINE_string( @@ -1982,7 +1981,7 @@ def bus_driver_scheduling(minimize_drivers: bool, max_num_drivers: int) -> int: # Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) status = solver.solve(model) diff --git a/examples/python/golomb_sat.py b/examples/python/golomb_sat.py index 0ed2240a6b..c785649d91 100644 --- a/examples/python/golomb_sat.py +++ b/examples/python/golomb_sat.py @@ -28,7 +28,6 @@ from typing import Sequence from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _ORDER = flags.DEFINE_integer("order", 8, "Order of the ruler.") @@ -71,7 +70,7 @@ def solve_golomb_ruler(order: int, params: str) -> None: # Solve the model. solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solution_printer = cp_model.ObjectiveSolutionPrinter() print(f"Golomb ruler(order={order})") status = solver.solve(model, solution_printer) diff --git a/examples/python/knapsack_2d_sat.py b/examples/python/knapsack_2d_sat.py index 5014d246e8..e66457a7ce 100644 --- a/examples/python/knapsack_2d_sat.py +++ b/examples/python/knapsack_2d_sat.py @@ -25,7 +25,6 @@ from absl import flags import numpy as np import pandas as pd -from google.protobuf import text_format from ortools.sat.python import cp_model @@ -158,7 +157,7 @@ def solve_with_duplicate_items( # Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) status = solver.solve(model) @@ -260,7 +259,7 @@ def solve_with_duplicate_optional_items( # solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) status = solver.solve(model) @@ -381,7 +380,7 @@ def solve_with_rotations(data: pd.Series, max_height: int, max_width: int): # solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) status = solver.solve(model) diff --git a/examples/python/line_balancing_sat.py b/examples/python/line_balancing_sat.py index 5cb513c52b..2e8c5e91c9 100644 --- a/examples/python/line_balancing_sat.py +++ b/examples/python/line_balancing_sat.py @@ -34,10 +34,9 @@ from typing import Dict, Sequence from absl import app from absl import flags -from google.protobuf import text_format - from ortools.sat.python import cp_model + _INPUT = flags.DEFINE_string("input", "", "Input file to parse and solve.") _PARAMS = flags.DEFINE_string("params", "", "Sat solver parameters.") _OUTPUT_PROTO = flags.DEFINE_string( @@ -273,7 +272,7 @@ def solve_problem_with_boolean_model( # solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solver.parameters.log_search_progress = True solver.solve(model) @@ -340,7 +339,7 @@ def solve_problem_with_scheduling_model( # solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solver.parameters.log_search_progress = True solver.solve(model) diff --git a/examples/python/maze_escape_sat.py b/examples/python/maze_escape_sat.py index 6d5e9c4796..30d2cf5042 100644 --- a/examples/python/maze_escape_sat.py +++ b/examples/python/maze_escape_sat.py @@ -26,7 +26,6 @@ from typing import Dict, Sequence, Tuple from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _OUTPUT_PROTO = flags.DEFINE_string( @@ -141,7 +140,7 @@ def escape_the_maze(params: str, output_proto: str) -> None: # Solve model. solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) solver.parameters.log_search_progress = True result = solver.solve(model) diff --git a/examples/python/memory_layout_and_infeasibility_sat.py b/examples/python/memory_layout_and_infeasibility_sat.py index 9956700c22..77b11e82c5 100644 --- a/examples/python/memory_layout_and_infeasibility_sat.py +++ b/examples/python/memory_layout_and_infeasibility_sat.py @@ -20,7 +20,6 @@ from typing import List from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model @@ -72,7 +71,7 @@ def solve_hard_model(output_proto: str, params: str) -> bool: solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) status = solver.solve(model) print(solver.response_stats()) @@ -158,7 +157,7 @@ def solve_soft_model_with_maximization(params: str) -> None: solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) status = solver.solve(model) print(solver.response_stats()) if status == cp_model.OPTIMAL or status == cp_model.FEASIBLE: diff --git a/examples/python/no_wait_baking_scheduling_sat.py b/examples/python/no_wait_baking_scheduling_sat.py index 9c81d75674..56b869f1f8 100644 --- a/examples/python/no_wait_baking_scheduling_sat.py +++ b/examples/python/no_wait_baking_scheduling_sat.py @@ -26,7 +26,6 @@ from typing import List, Sequence, Tuple from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _PARAMS = flags.DEFINE_string( @@ -287,7 +286,7 @@ def solve_with_cp_sat( # Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solver.parameters.log_search_progress = True status = solver.solve(model) diff --git a/examples/python/pentominoes_sat.py b/examples/python/pentominoes_sat.py index 01479ec6ac..4e041d3272 100644 --- a/examples/python/pentominoes_sat.py +++ b/examples/python/pentominoes_sat.py @@ -31,7 +31,6 @@ from typing import Dict, List from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model @@ -144,7 +143,7 @@ def generate_and_solve_problem(pieces: Dict[str, List[List[int]]]) -> None: # Solve the model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) status = solver.solve(model) print( diff --git a/examples/python/rcpsp_sat.py b/examples/python/rcpsp_sat.py index 2b78e3d049..aacfb7d523 100644 --- a/examples/python/rcpsp_sat.py +++ b/examples/python/rcpsp_sat.py @@ -26,10 +26,9 @@ import collections from absl import app from absl import flags -from google.protobuf import text_format -from ortools.sat.python import cp_model from ortools.scheduling import rcpsp_pb2 from ortools.scheduling.python import rcpsp +from ortools.sat.python import cp_model _INPUT = flags.DEFINE_string("input", "", "Input file to parse and solve.") _OUTPUT_PROTO = flags.DEFINE_string( @@ -361,7 +360,7 @@ def solve_rcpsp( # Parse user specified parameters. if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) # Favor objective_shaving over objective_lb_search. if solver.parameters.num_workers >= 16 and solver.parameters.num_workers < 24: diff --git a/examples/python/shift_scheduling_sat.py b/examples/python/shift_scheduling_sat.py index a81c083991..9883124176 100644 --- a/examples/python/shift_scheduling_sat.py +++ b/examples/python/shift_scheduling_sat.py @@ -17,7 +17,6 @@ from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _OUTPUT_PROTO = flags.DEFINE_string( @@ -410,7 +409,7 @@ def solve_shift_scheduling(params: str, output_proto: str): # Solve the model. solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) solution_printer = cp_model.ObjectiveSolutionPrinter() status = solver.solve(model, solution_printer) diff --git a/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py b/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py index c54a67d26a..d1d2039b86 100644 --- a/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py +++ b/examples/python/single_machine_scheduling_with_setup_release_due_dates_sat.py @@ -17,7 +17,6 @@ from typing import Sequence from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model # ---------------------------------------------------------------------------- @@ -498,7 +497,7 @@ def single_machine_scheduling(): # Solve. solver = cp_model.CpSolver() if parameters: - text_format.Parse(parameters, solver.parameters) + solver.parameters.parse_text_format(parameters) solution_printer = SolutionPrinter() solver.best_bound_callback = lambda a: print(f"New objective lower bound: {a}") solver.solve(model, solution_printer) diff --git a/examples/python/spread_robots_sat.py b/examples/python/spread_robots_sat.py index b9fc5999e6..476de60c68 100644 --- a/examples/python/spread_robots_sat.py +++ b/examples/python/spread_robots_sat.py @@ -18,7 +18,6 @@ import math from typing import Sequence from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _NUM_ROBOTS = flags.DEFINE_integer("num_robots", 8, "Number of robots to place.") @@ -93,7 +92,7 @@ def spread_robots(num_robots: int, room_size: int, params: str) -> None: # Creates a solver and solves the model. solver = cp_model.CpSolver() if params: - text_format.Parse(params, solver.parameters) + solver.parameters.parse_text_format(params) solver.parameters.log_search_progress = True status = solver.solve(model) diff --git a/examples/python/steel_mill_slab_sat.py b/examples/python/steel_mill_slab_sat.py index e84b490cf9..6be2e85287 100644 --- a/examples/python/steel_mill_slab_sat.py +++ b/examples/python/steel_mill_slab_sat.py @@ -22,7 +22,6 @@ import time from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model @@ -294,7 +293,7 @@ def steel_mill_slab(problem_id: int, break_symmetries: bool) -> None: ### Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) objective_printer = cp_model.ObjectiveSolutionPrinter() status = solver.solve(model, objective_printer) @@ -478,7 +477,7 @@ def steel_mill_slab_with_valid_slabs(problem_id: int, break_symmetries: bool) -> ### Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solution_printer = SteelMillSlabSolutionPrinter(orders, assign, loads, losses) status = solver.solve(model, solution_printer) @@ -548,7 +547,7 @@ def steel_mill_slab_with_column_generation(problem_id: int) -> None: ### Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solution_printer = cp_model.ObjectiveSolutionPrinter() status = solver.solve(model, solution_printer) diff --git a/examples/python/sudoku_sat.py b/examples/python/sudoku_sat.py index 069322242c..a66425eae9 100755 --- a/examples/python/sudoku_sat.py +++ b/examples/python/sudoku_sat.py @@ -68,8 +68,12 @@ def solve_sudoku() -> None: if initial_grid[i][j]: model.add(grid[(i, j)] == initial_grid[i][j]) + model.export_to_file('/tmp/sudoku_sat.pb.txt') + # Solves and prints out the solution. solver = cp_model.CpSolver() + solver.parameters.num_workers = 1 + solver.parameters.log_search_progress = True status = solver.solve(model) if status == cp_model.OPTIMAL: for i in line: diff --git a/examples/python/test_scheduling_sat.py b/examples/python/test_scheduling_sat.py index fdec2ba13b..e3feb1cfed 100644 --- a/examples/python/test_scheduling_sat.py +++ b/examples/python/test_scheduling_sat.py @@ -33,7 +33,6 @@ from absl import app from absl import flags import pandas as pd -from google.protobuf import text_format from ortools.sat.python import cp_model @@ -141,7 +140,7 @@ def solve( # Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) status = solver.solve(model) # Report solution. diff --git a/examples/python/weighted_latency_problem_sat.py b/examples/python/weighted_latency_problem_sat.py index 0abc315afe..fbc1bdd173 100644 --- a/examples/python/weighted_latency_problem_sat.py +++ b/examples/python/weighted_latency_problem_sat.py @@ -20,7 +20,6 @@ from typing import Sequence from absl import app from absl import flags -from google.protobuf import text_format from ortools.sat.python import cp_model _NUM_NODES = flags.DEFINE_integer("num_nodes", 12, "Number of nodes to visit.") @@ -99,7 +98,7 @@ def solve_with_cp_sat(x, y, profits) -> None: # Solve model. solver = cp_model.CpSolver() if _PARAMS.value: - text_format.Parse(_PARAMS.value, solver.parameters) + solver.parameters.parse_text_format(_PARAMS.value) solver.parameters.log_search_progress = True solver.solve(model) diff --git a/ortools/python/BUILD.bazel b/ortools/python/BUILD.bazel index 56b9e6c65d..aa06a88222 100644 --- a/ortools/python/BUILD.bazel +++ b/ortools/python/BUILD.bazel @@ -24,6 +24,8 @@ py_binary( "//ortools/graph/python:max_flow.so", "//ortools/graph/python:min_cost_flow.so", "//ortools/sat/python:cp_model_helper.so", + "//ortools/sat/python:cp_model_builder_pybind.so", + "//ortools/sat/python:sat_parameters_builder_pybind.so", ], tags = ["manual"], deps = [ @@ -32,7 +34,6 @@ py_binary( "//ortools/sat/colab:flags", "//ortools/sat/colab:visualization", "//ortools/sat/python:cp_model", - "//ortools/sat/python:cp_model_numbers", requirement("notebook"), requirement("svgwrite"), requirement("plotly"), diff --git a/ortools/python/setup.py.in b/ortools/python/setup.py.in index d148a1b0a3..7a240263e9 100644 --- a/ortools/python/setup.py.in +++ b/ortools/python/setup.py.in @@ -126,6 +126,8 @@ setup( '@PYTHON_PROJECT@.sat.colab':['*.pyi', 'py.typed'], '@PYTHON_PROJECT@.sat.python':[ '$', + '$', + '$', '*.pyi', 'py.typed' ], diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 7c4a6e67e8..3b76b64d43 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -17,7 +17,7 @@ load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") load("@protobuf//bazel:java_proto_library.bzl", "java_proto_library") load("@protobuf//bazel:proto_library.bzl", "proto_library") load("@protobuf//bazel:py_proto_library.bzl", "py_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") +load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test", "cc_shared_library") load("@rules_go//proto:def.bzl", "go_proto_library") package(default_visibility = ["//visibility:public"]) @@ -220,6 +220,7 @@ cc_test( cc_proto_library( name = "cp_model_cc_proto", deps = [":cp_model_proto"], + linkshared = True, ) py_proto_library( @@ -3578,7 +3579,6 @@ cc_library( ":synchronization", ":timetable", ":util", - "//ortools/base:stl_util", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", diff --git a/ortools/sat/constraint_violation.cc b/ortools/sat/constraint_violation.cc index 4f41c10499..f59bff555f 100644 --- a/ortools/sat/constraint_violation.cc +++ b/ortools/sat/constraint_violation.cc @@ -1500,7 +1500,7 @@ LsEvaluator::LsEvaluator(const CpModelProto& cp_model, LsEvaluator::LsEvaluator( const CpModelProto& cp_model, const SatParameters& params, const std::vector& ignored_constraints, - const std::vector& additional_constraints, + absl::Span additional_constraints, TimeLimit* time_limit) : cp_model_(cp_model), params_(params), time_limit_(time_limit) { var_to_constraints_.resize(cp_model_.variables_size()); diff --git a/ortools/sat/constraint_violation.h b/ortools/sat/constraint_violation.h index 768fde9514..cc09718d24 100644 --- a/ortools/sat/constraint_violation.h +++ b/ortools/sat/constraint_violation.h @@ -313,7 +313,7 @@ class LsEvaluator { TimeLimit* time_limit); LsEvaluator(const CpModelProto& cp_model, const SatParameters& params, const std::vector& ignored_constraints, - const std::vector& additional_constraints, + absl::Span additional_constraints, TimeLimit* time_limit); // Intersects the domain of the objective with [lb..ub]. diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index e6b29ef50b..f95d2c7c0e 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -726,7 +726,6 @@ absl::flat_hash_map GetNamedParameters( SatParameters new_params = base_params; new_params.set_use_shared_tree_search(true); new_params.set_search_branching(SatParameters::AUTOMATIC_SEARCH); - new_params.set_linearization_level(0); // These settings don't make sense with shared tree search, turn them off as // they can break things. diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index 6af621ec36..efdf54130b 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -37,7 +37,7 @@ #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" -ABSL_FLAG(bool, cp_model_dump_models, false, +ABSL_FLAG(bool, cp_model_dump_models, true, "DEBUG ONLY. When set to true, SolveCpModel() will dump its model " "protos (original model, presolved model, mapping model) in text " "format to 'FLAGS_cp_model_dump_prefix'{model|presolved_model|" @@ -955,6 +955,7 @@ void RegisterFieldPrinters( if (field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { if (field->message_type() == IntegerVariableProto::descriptor() || field->message_type() == LinearExpressionProto::descriptor()) { + LOG(INFO) << "########### Register printer"; printer->RegisterFieldValuePrinter(field, new InlineFieldPrinter()); } else { RegisterFieldPrinters(field->message_type(), descriptors, printer); diff --git a/ortools/sat/feasibility_pump.cc b/ortools/sat/feasibility_pump.cc index 8eb2f49b19..6f2d9a1dc1 100644 --- a/ortools/sat/feasibility_pump.cc +++ b/ortools/sat/feasibility_pump.cc @@ -681,7 +681,9 @@ bool FeasibilityPump::PropagationRounding() { } if (!sat_solver_->FinishPropagation()) return false; - sat_solver_->EnqueueDecisionAndBacktrackOnConflict(to_enqueue); + const SatSolver::Status decision_status = + sat_solver_->EnqueueDecisionAndBacktrackOnConflict(to_enqueue); + if (decision_status != SatSolver::Status::FEASIBLE) return false; if (sat_solver_->ModelIsUnsat()) return false; } integer_solution_is_set_ = true; diff --git a/ortools/sat/go/cpmodel/cp_model.go b/ortools/sat/go/cpmodel/cp_model.go index 99dbf687ca..a535284af8 100644 --- a/ortools/sat/go/cpmodel/cp_model.go +++ b/ortools/sat/go/cpmodel/cp_model.go @@ -29,6 +29,7 @@ import ( "sort" log "github.com/golang/glog" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/go/cpmodel/cp_model_test.go b/ortools/sat/go/cpmodel/cp_model_test.go index 4d8ecf7e86..d2ed3d821a 100644 --- a/ortools/sat/go/cpmodel/cp_model_test.go +++ b/ortools/sat/go/cpmodel/cp_model_test.go @@ -22,8 +22,9 @@ import ( log "github.com/golang/glog" "github.com/google/go-cmp/cmp" - cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" "google.golang.org/protobuf/testing/protocmp" + + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func Example() { diff --git a/ortools/sat/opb_reader.h b/ortools/sat/opb_reader.h index cb6a36d500..ef452c2e7f 100644 --- a/ortools/sat/opb_reader.h +++ b/ortools/sat/opb_reader.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -87,6 +88,7 @@ class OpbReader { LOG(INFO) << "#variables: " << num_variables_; LOG(INFO) << "#constraints: " << constraints_.size(); LOG(INFO) << "#objective: " << objective_.size(); + if (top_cost_.has_value()) LOG(INFO) << "top_cost: " << top_cost_.value(); const std::string error_message = ValidateModel(); if (!error_message.empty()) { @@ -134,14 +136,16 @@ class OpbReader { void ProcessNewLine(const std::string& line) { const std::vector words = absl::StrSplit(line, absl::ByAnyChar(" ;"), absl::SkipEmpty()); - if (words.empty() || words[0].empty() || words[0][0] == '*') { - // TODO(user): Parse comments. + if (words.empty() || words[0].empty() || words[0][0] == '*') return; + + if (words[0] == "soft:") { + if (words.size() == 1) return; + int64_t top_cost; + if (!ParseInt64Into(words[1], &top_cost)) return; + top_cost_ = top_cost; return; } - // We ignore the number of soft constraints. - if (words[0] == "soft:") return; - if (words[0] == "min:") { for (int i = 1; i < words.size(); ++i) { const std::string& word = words[i]; @@ -364,6 +368,12 @@ class OpbReader { obj->add_coeffs(term.coeff); } } + + if (top_cost_.has_value()) { + CpObjectiveProto* obj = model->mutable_objective(); + obj->add_domain(std::numeric_limits::min()); + obj->add_domain(top_cost_.value()); + } } int num_variables_; @@ -371,6 +381,7 @@ class OpbReader { std::vector constraints_; absl::flat_hash_map, int> product_to_var_; bool model_is_supported_ = true; + std::optional top_cost_; }; } // namespace sat diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index 36af16ffab..7c0de26f54 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -141,6 +141,8 @@ std::string ValidateParameters(const SatParameters& params) { TEST_POSITIVE(glucose_decay_increment_period); TEST_POSITIVE(shared_tree_max_nodes_per_worker); TEST_POSITIVE(shared_tree_open_leaves_per_worker); + TEST_NON_NEGATIVE(shared_tree_split_min_dtime); + TEST_IS_FINITE(shared_tree_split_min_dtime); TEST_POSITIVE(mip_var_scaling); // Test LP tolerances. diff --git a/ortools/sat/primary_variables.cc b/ortools/sat/primary_variables.cc index 3071139260..c32b1b0654 100644 --- a/ortools/sat/primary_variables.cc +++ b/ortools/sat/primary_variables.cc @@ -411,6 +411,21 @@ VariableRelationships ComputeVariableRelationships(const CpModelProto& model) { -num_times_variable_appears_as_preferred_to_deduce[b], -num_times_variable_appears_as_deducible[b]); }); + + // Put in front of the queue all the variables that can readily be deduced + // using some constraint. + for (int c = 0; c < model.constraints_size(); ++c) { + ConstraintData& data = constraint_data[c]; + if (data.input_vars.size() + data.deducible_vars.size() != 1) { + continue; + } + if (!data.deducible_vars.empty()) { + vars_queue.push_front(*data.deducible_vars.begin()); + } else if (data.is_linear_inequality) { + vars_queue.push_front(*data.input_vars.begin()); + } + } + std::vector constraints_to_check; while (!vars_queue.empty()) { const int v = vars_queue.front(); diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index d73f0fa96f..20c4da59f1 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -28,6 +28,7 @@ cc_library( hdrs = ["linear_expr.h"], deps = [ "//ortools/sat:cp_model_cc_proto", + "//ortools/sat:cp_model_utils", "//ortools/util:fp_roundtrip_conv", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/container:btree", @@ -38,6 +39,110 @@ cc_library( ], ) +cc_library( + name = "wrappers", + srcs = ["wrappers.cc"], + hdrs = ["wrappers.h"], + deps = [ + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/log:die_if_null", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:span", + "@protobuf", + ], +) + +cc_binary( + name = "gen_cp_model_builder_pybind", + srcs = ["gen_cp_model_builder_pybind.cc"], + deps = [ + ":wrappers", + "//ortools/base", + "//ortools/sat:cp_model_cc_proto", + "@abseil-cpp//absl/log:die_if_null", + "@abseil-cpp//absl/strings:str_format", + ], +) + +genrule( + name = "run_gen_cp_model_builder_pybind", + outs = ["cp_model_builder_pybind.cc"], + cmd = "$(location :gen_cp_model_builder_pybind) > $@", + tools = [":gen_cp_model_builder_pybind"], +) + +pybind_extension( + name = "cp_model_builder", + srcs = [ + "cp_model_builder_pybind.cc", + ], + visibility = ["//visibility:public"], + deps = [ + "//ortools/port:proto_utils", + "//ortools/sat:cp_model_cc_proto", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/strings", + "@protobuf", + ], +) + +py_test( + name = "cp_model_builder_test", + srcs = ["cp_model_builder_test.py"], + deps = [ + ":cp_model_builder", + requirement("absl-py"), + "//ortools/sat:cp_model_py_pb2", + ], +) + +cc_binary( + name = "gen_sat_parameters_builder_pybind", + srcs = ["gen_sat_parameters_builder_pybind.cc"], + deps = [ + ":wrappers", + "//ortools/base", + "//ortools/sat:sat_parameters_cc_proto", + "@abseil-cpp//absl/log:die_if_null", + "@abseil-cpp//absl/strings:str_format", + ], +) + +genrule( + name = "run_gen_sat_parameters_builder_pybind", + outs = ["sat_parameters_builder_pybind.cc"], + cmd = "$(location :gen_sat_parameters_builder_pybind) > $@", + tools = [":gen_sat_parameters_builder_pybind"], +) + +pybind_extension( + name = "sat_parameters_builder", + srcs = [ + "sat_parameters_builder_pybind.cc", + ], + visibility = ["//visibility:public"], + deps = [ + "//ortools/port:proto_utils", + "//ortools/sat:sat_parameters_cc_proto", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/strings", + "@protobuf", + ], +) + +py_test( + name = "sat_parameters_builder_test", + srcs = ["sat_parameters_builder_test.py"], + deps = [ + ":sat_parameters_builder", + requirement("absl-py"), + ], +) + pybind_extension( name = "cp_model_helper", srcs = ["cp_model_helper.cc"], @@ -50,7 +155,6 @@ pybind_extension( "//ortools/sat:sat_parameters_cc_proto", "//ortools/sat:swig_helper", "@abseil-cpp//absl/strings", - "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", ], ) @@ -58,45 +162,24 @@ py_test( name = "cp_model_helper_test", srcs = ["cp_model_helper_test.py"], deps = [ + ":cp_model_builder", ":cp_model_helper", - "//ortools/sat:cp_model_py_pb2", - "//ortools/sat:sat_parameters_py_pb2", + ":sat_parameters_builder", "//ortools/util/python:sorted_interval_list", requirement("absl-py"), ], ) -py_library( - name = "cp_model_numbers", - srcs = ["cp_model_numbers.py"], - visibility = ["//visibility:public"], - deps = [ - ":cp_model_helper", - requirement("numpy"), - "@protobuf//:protobuf_python", - ], -) - -py_test( - name = "cp_model_numbers_test", - srcs = ["cp_model_numbers_test.py"], - deps = [ - ":cp_model_numbers", - requirement("absl-py"), - ], -) - py_library( name = "cp_model", srcs = ["cp_model.py"], visibility = ["//visibility:public"], deps = [ + ":cp_model_builder", ":cp_model_helper", - ":cp_model_numbers", + ":sat_parameters_builder", requirement("numpy"), requirement("pandas"), - "//ortools/sat:cp_model_py_pb2", - "//ortools/sat:sat_parameters_py_pb2", "//ortools/util/python:sorted_interval_list", ], ) diff --git a/ortools/sat/python/CMakeLists.txt b/ortools/sat/python/CMakeLists.txt index 6e91bfdefb..504c457325 100644 --- a/ortools/sat/python/CMakeLists.txt +++ b/ortools/sat/python/CMakeLists.txt @@ -11,6 +11,135 @@ # See the License for the specific language governing permissions and # limitations under the License. +set(WRAPPERS_NAME sat_python_wrappers) + +add_library(${WRAPPERS_NAME} OBJECT wrappers.h wrappers.cc) +set_target_properties(${WRAPPERS_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON) +target_include_directories(${WRAPPERS_NAME} PUBLIC + ${PROJECT_SOURCE_DIR} + ${PROJECT_BINARY_DIR}) +target_link_libraries(${WRAPPERS_NAME} PUBLIC + absl::memory + absl::synchronization + absl::str_format + protobuf::libprotobuf) +add_library(${PROJECT_NAMESPACE}::${WRAPPERS_NAME} ALIAS ${WRAPPERS_NAME}) + +# gen_cp_model_builder_pybind +add_executable(gen_cp_model_builder_pybind) +target_sources(gen_cp_model_builder_pybind PRIVATE "gen_cp_model_builder_pybind.cc") +target_include_directories(gen_cp_model_builder_pybind PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_compile_features(gen_cp_model_builder_pybind PRIVATE cxx_std_17) +target_link_libraries(gen_cp_model_builder_pybind PRIVATE + absl::flags_commandlineflag + absl::flags_parse + absl::flags_usage + absl::die_if_null + absl::str_format + protobuf::libprotobuf + ${PROJECT_NAMESPACE}::ortools_proto + ${PROJECT_NAMESPACE}::${WRAPPERS_NAME}) + +include(GNUInstallDirs) +if(APPLE) + set_target_properties(gen_cp_model_builder_pybind PROPERTIES INSTALL_RPATH + "@loader_path/../${CMAKE_INSTALL_LIBDIR};@loader_path") +elseif(UNIX) + cmake_path(RELATIVE_PATH CMAKE_INSTALL_FULL_LIBDIR + BASE_DIRECTORY ${CMAKE_INSTALL_FULL_BINDIR} + OUTPUT_VARIABLE libdir_relative_path) + set_target_properties(gen_cp_model_builder_pybind PROPERTIES + INSTALL_RPATH "$ORIGIN/${libdir_relative_path}") +endif() + +install(TARGETS gen_cp_model_builder_pybind) + +add_custom_command( + OUTPUT cp_model_builder_pybind.cc + COMMAND gen_cp_model_builder_pybind > cp_model_builder_pybind.cc + # DEPENDS ${PROTO_FILE} ${PROTOC_PRG} + COMMENT "Generate C++ protocol buffer for ${PROTO_FILE}" + VERBATIM) + + +pybind11_add_module(cp_model_builder_pybind MODULE cp_model_builder_pybind.cc) +set_target_properties(cp_model_builder_pybind PROPERTIES + LIBRARY_OUTPUT_NAME "cp_model_builder") + +# note: macOS is APPLE and also UNIX ! +if(APPLE) + set_target_properties(cp_model_builder_pybind PROPERTIES + SUFFIX ".so" + INSTALL_RPATH "@loader_path;@loader_path/../../../${PYTHON_PROJECT}/.libs") +elseif(UNIX) + set_target_properties(cp_model_builder_pybind PROPERTIES + INSTALL_RPATH "$ORIGIN:$ORIGIN/../../../${PYTHON_PROJECT}/.libs") +endif() +target_link_libraries(cp_model_builder_pybind PRIVATE + ${PROJECT_NAMESPACE}::ortools + protobuf::libprotobuf) + +target_include_directories(cp_model_builder_pybind PRIVATE ${protobuf_SOURCE_DIR}) +add_library(${PROJECT_NAMESPACE}::cp_model_builder_pybind ALIAS cp_model_builder_pybind) + +# gen_sat_parameters_builder_pybind +add_executable(gen_sat_parameters_builder_pybind) +target_sources(gen_sat_parameters_builder_pybind PRIVATE "gen_sat_parameters_builder_pybind.cc") +target_include_directories(gen_sat_parameters_builder_pybind PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_compile_features(gen_sat_parameters_builder_pybind PRIVATE cxx_std_17) +target_link_libraries(gen_sat_parameters_builder_pybind PRIVATE + absl::flags_commandlineflag + absl::flags_parse + absl::flags_usage + absl::die_if_null + absl::str_format + protobuf::libprotobuf + ${PROJECT_NAMESPACE}::ortools_proto + ${PROJECT_NAMESPACE}::${WRAPPERS_NAME}) + +include(GNUInstallDirs) +if(APPLE) + set_target_properties(gen_sat_parameters_builder_pybind PROPERTIES INSTALL_RPATH + "@loader_path/../${CMAKE_INSTALL_LIBDIR};@loader_path") +elseif(UNIX) + cmake_path(RELATIVE_PATH CMAKE_INSTALL_FULL_LIBDIR + BASE_DIRECTORY ${CMAKE_INSTALL_FULL_BINDIR} + OUTPUT_VARIABLE libdir_relative_path) + set_target_properties(gen_sat_parameters_builder_pybind PROPERTIES + INSTALL_RPATH "$ORIGIN/${libdir_relative_path}") +endif() + +install(TARGETS gen_sat_parameters_builder_pybind) + +add_custom_command( + OUTPUT sat_parameters_builder_pybind.cc + COMMAND gen_sat_parameters_builder_pybind > sat_parameters_builder_pybind.cc + # DEPENDS ${PROTO_FILE} ${PROTOC_PRG} + COMMENT "Generate C++ protocol buffer for ${PROTO_FILE}" + VERBATIM) + + +pybind11_add_module(sat_parameters_builder_pybind MODULE sat_parameters_builder_pybind.cc) +set_target_properties(sat_parameters_builder_pybind PROPERTIES + LIBRARY_OUTPUT_NAME "sat_parameters_builder") + +# note: macOS is APPLE and also UNIX ! +if(APPLE) + set_target_properties(sat_parameters_builder_pybind PROPERTIES + SUFFIX ".so" + INSTALL_RPATH "@loader_path;@loader_path/../../../${PYTHON_PROJECT}/.libs") +elseif(UNIX) + set_target_properties(sat_parameters_builder_pybind PROPERTIES + INSTALL_RPATH "$ORIGIN:$ORIGIN/../../../${PYTHON_PROJECT}/.libs") +endif() +target_link_libraries(sat_parameters_builder_pybind PRIVATE + ${PROJECT_NAMESPACE}::ortools + protobuf::libprotobuf) + +target_include_directories(sat_parameters_builder_pybind PRIVATE ${protobuf_SOURCE_DIR}) +add_library(${PROJECT_NAMESPACE}::sat_parameters_builder_pybind ALIAS sat_parameters_builder_pybind) + pybind11_add_module(cp_model_helper_pybind11 MODULE cp_model_helper.cc) set_target_properties(cp_model_helper_pybind11 PROPERTIES LIBRARY_OUTPUT_NAME "cp_model_helper") @@ -26,8 +155,8 @@ elseif(UNIX) endif() target_link_libraries(cp_model_helper_pybind11 PRIVATE ${PROJECT_NAMESPACE}::ortools - pybind11_native_proto_caster - protobuf::libprotobuf) + protobuf::libprotobuf + ) target_include_directories(cp_model_helper_pybind11 PRIVATE ${protobuf_SOURCE_DIR}) add_library(${PROJECT_NAMESPACE}::cp_model_helper_pybind11 ALIAS cp_model_helper_pybind11) @@ -38,3 +167,4 @@ if(BUILD_TESTING) add_python_test(FILE_NAME ${FILE_NAME}) endforeach() endif() + diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 435be0f7ba..236eb3cba7 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -45,7 +45,6 @@ Other methods and functions listed are primarily used for developing OR-Tools, rather than for solving specific optimization problems. """ -import copy import threading import time from typing import ( @@ -65,10 +64,9 @@ import warnings import numpy as np import pandas as pd -from ortools.sat import cp_model_pb2 -from ortools.sat import sat_parameters_pb2 +from ortools.sat.python import cp_model_builder as cmb from ortools.sat.python import cp_model_helper as cmh -from ortools.sat.python import cp_model_numbers as cmn +from ortools.sat.python import sat_parameters_builder as spb from ortools.util.python import sorted_interval_list # Import external types. @@ -77,6 +75,7 @@ BoundedLinearExpression = cmh.BoundedLinearExpression FlatFloatExpr = cmh.FlatFloatExpr FlatIntExpr = cmh.FlatIntExpr LinearExpr = cmh.LinearExpr +IntVar = cmh.IntVar NotBooleanVariable = cmh.NotBooleanVariable @@ -90,39 +89,52 @@ INT32_MIN = -(2**31) INT32_MAX = 2**31 - 1 # CpSolver status (exported to avoid importing cp_model_cp2). -UNKNOWN = cp_model_pb2.UNKNOWN -MODEL_INVALID = cp_model_pb2.MODEL_INVALID -FEASIBLE = cp_model_pb2.FEASIBLE -INFEASIBLE = cp_model_pb2.INFEASIBLE -OPTIMAL = cp_model_pb2.OPTIMAL +UNKNOWN = cmb.CpSolverStatus.UNKNOWN +UNKNOWN = cmb.CpSolverStatus.UNKNOWN +MODEL_INVALID = cmb.CpSolverStatus.MODEL_INVALID +FEASIBLE = cmb.CpSolverStatus.FEASIBLE +INFEASIBLE = cmb.CpSolverStatus.INFEASIBLE +OPTIMAL = cmb.CpSolverStatus.OPTIMAL # Variable selection strategy -CHOOSE_FIRST = cp_model_pb2.DecisionStrategyProto.CHOOSE_FIRST -CHOOSE_LOWEST_MIN = cp_model_pb2.DecisionStrategyProto.CHOOSE_LOWEST_MIN -CHOOSE_HIGHEST_MAX = cp_model_pb2.DecisionStrategyProto.CHOOSE_HIGHEST_MAX -CHOOSE_MIN_DOMAIN_SIZE = cp_model_pb2.DecisionStrategyProto.CHOOSE_MIN_DOMAIN_SIZE -CHOOSE_MAX_DOMAIN_SIZE = cp_model_pb2.DecisionStrategyProto.CHOOSE_MAX_DOMAIN_SIZE +CHOOSE_FIRST = cmb.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_FIRST +CHOOSE_LOWEST_MIN = ( + cmb.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_LOWEST_MIN +) +CHOOSE_HIGHEST_MAX = ( + cmb.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_HIGHEST_MAX +) +CHOOSE_MIN_DOMAIN_SIZE = ( + cmb.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_MIN_DOMAIN_SIZE +) +CHOOSE_MAX_DOMAIN_SIZE = ( + cmb.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_MAX_DOMAIN_SIZE +) # Domain reduction strategy -SELECT_MIN_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MIN_VALUE -SELECT_MAX_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MAX_VALUE -SELECT_LOWER_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_LOWER_HALF -SELECT_UPPER_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_UPPER_HALF -SELECT_MEDIAN_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MEDIAN_VALUE -SELECT_RANDOM_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_RANDOM_HALF +SELECT_MIN_VALUE = cmb.DecisionStrategyProto.DomainReductionStrategy.SELECT_MIN_VALUE +SELECT_MAX_VALUE = cmb.DecisionStrategyProto.DomainReductionStrategy.SELECT_MAX_VALUE +SELECT_LOWER_HALF = cmb.DecisionStrategyProto.DomainReductionStrategy.SELECT_LOWER_HALF +SELECT_UPPER_HALF = cmb.DecisionStrategyProto.DomainReductionStrategy.SELECT_UPPER_HALF +SELECT_MEDIAN_VALUE = ( + cmb.DecisionStrategyProto.DomainReductionStrategy.SELECT_MEDIAN_VALUE +) +SELECT_RANDOM_HALF = ( + cmb.DecisionStrategyProto.DomainReductionStrategy.SELECT_RANDOM_HALF +) # Search branching -AUTOMATIC_SEARCH = sat_parameters_pb2.SatParameters.AUTOMATIC_SEARCH -FIXED_SEARCH = sat_parameters_pb2.SatParameters.FIXED_SEARCH -PORTFOLIO_SEARCH = sat_parameters_pb2.SatParameters.PORTFOLIO_SEARCH -LP_SEARCH = sat_parameters_pb2.SatParameters.LP_SEARCH -PSEUDO_COST_SEARCH = sat_parameters_pb2.SatParameters.PSEUDO_COST_SEARCH +AUTOMATIC_SEARCH = spb.SatParameters.SearchBranching.AUTOMATIC_SEARCH +FIXED_SEARCH = spb.SatParameters.SearchBranching.FIXED_SEARCH +PORTFOLIO_SEARCH = spb.SatParameters.SearchBranching.PORTFOLIO_SEARCH +LP_SEARCH = spb.SatParameters.SearchBranching.LP_SEARCH +PSEUDO_COST_SEARCH = spb.SatParameters.SearchBranching.PSEUDO_COST_SEARCH PORTFOLIO_WITH_QUICK_RESTART_SEARCH = ( - sat_parameters_pb2.SatParameters.PORTFOLIO_WITH_QUICK_RESTART_SEARCH + spb.SatParameters.SearchBranching.PORTFOLIO_WITH_QUICK_RESTART_SEARCH ) -HINT_SEARCH = sat_parameters_pb2.SatParameters.HINT_SEARCH -PARTIAL_FIXED_SEARCH = sat_parameters_pb2.SatParameters.PARTIAL_FIXED_SEARCH -RANDOMIZED_SEARCH = sat_parameters_pb2.SatParameters.RANDOMIZED_SEARCH +HINT_SEARCH = spb.SatParameters.SearchBranching.HINT_SEARCH +PARTIAL_FIXED_SEARCH = spb.SatParameters.SearchBranching.PARTIAL_FIXED_SEARCH +RANDOMIZED_SEARCH = spb.SatParameters.SearchBranching.RANDOMIZED_SEARCH # Type aliases IntegralT = Union[int, np.int8, np.uint8, np.int32, np.uint32, np.int64, np.uint64] @@ -170,34 +182,17 @@ ArcT = Tuple[IntegralT, IntegralT, LiteralT] _IndexOrSeries = Union[pd.Index, pd.Series] -def display_bounds(bounds: Sequence[int]) -> str: - """Displays a flattened list of intervals.""" - out = "" - for i in range(0, len(bounds), 2): - if i != 0: - out += ", " - if bounds[i] == bounds[i + 1]: - out += str(bounds[i]) - else: - out += str(bounds[i]) + ".." + str(bounds[i + 1]) - return out - - -def short_name(model: cp_model_pb2.CpModelProto, i: int) -> str: +def short_name(model: cmb.CpModelProto, i: int) -> str: """Returns a short name of an integer variable, or its negation.""" - if i < 0: - return f"not({short_name(model, -i - 1)})" - v = model.variables[i] - if v.name: - return v.name - elif len(v.domain) == 2 and v.domain[0] == v.domain[1]: - return str(v.domain[0]) + if i >= 0: + return str(IntVar(model, i)) else: - return f"[{display_bounds(v.domain)}]" + return f"not({IntVar(model, -i - 1)})" def short_expr_name( - model: cp_model_pb2.CpModelProto, e: cp_model_pb2.LinearExpressionProto + model: cmb.CpModelProto, + e: cmb.LinearExpressionProto, ) -> str: """Pretty-print LinearExpressionProto instances.""" if not e.vars: @@ -221,106 +216,55 @@ def short_expr_name( return str(e) -class IntVar(cmh.BaseIntVar): - """An integer variable. +def arg_is_boolean(x: Any) -> bool: + """Checks if the x is a boolean.""" + if isinstance(x, bool): + return True + if isinstance(x, np.bool_): + return True + return False - An IntVar is an object that can take on any integer value within defined - ranges. Variables appear in constraint like: - x + y >= 5 - AllDifferent([x, y, z]) - - Solving a model is equivalent to finding, for each variable, a single value - from the set of initial values (called the initial domain), such that the - model is feasible, or optimal if you provided an objective function. - """ - - def __init__( - self, - model: cp_model_pb2.CpModelProto, - domain: Union[int, sorted_interval_list.Domain], - is_boolean: bool, - name: Optional[str], - ) -> None: - """See CpModel.new_int_var below.""" - self.__model: cp_model_pb2.CpModelProto = model - # Python do not support multiple __init__ methods. - # This method is only called from the CpModel class. - # We hack the parameter to support the two cases: - # case 1: - # model is a CpModelProto, domain is a Domain, and name is a string. - # case 2: - # model is a CpModelProto, domain is an index (int), and name is None. - if isinstance(domain, IntegralTypes) and name is None: - cmh.BaseIntVar.__init__(self, int(domain), is_boolean) +def rebuild_from_linear_expression_proto( + proto: cmb.LinearExpressionProto, + model_proto: cmb.CpModelProto, +) -> LinearExprT: + """Recreate a LinearExpr from a LinearExpressionProto.""" + num_elements = len(proto.vars) + if num_elements == 0: + return proto.offset + elif num_elements == 1: + var = IntVar(model_proto, proto.vars[0]) + return LinearExpr.affine( + var, proto.coeffs[0], proto.offset + ) # pytype: disable=bad-return-type + else: + variables = [] + for var_index in range(len(proto.vars)): + var = IntVar(model_proto, var_index) + variables.append(var) + if proto.offset != 0: + coeffs = [] + coeffs.extend(proto.coeffs) + coeffs.append(1) + variables.append(proto.offset) + return LinearExpr.weighted_sum(variables, coeffs) else: - cmh.BaseIntVar.__init__(self, len(model.variables), is_boolean) - proto: cp_model_pb2.IntegerVariableProto = self.__model.variables.add() - proto.domain.extend( - cast(sorted_interval_list.Domain, domain).flattened_intervals() - ) - if name is not None: - proto.name = name + return LinearExpr.weighted_sum(variables, proto.coeffs) - def __copy__(self) -> "IntVar": - """Returns a shallowcopy of the variable.""" - return IntVar(self.__model, self.index, self.is_boolean, None) - def __deepcopy__(self, memo: Any) -> "IntVar": - """Returns a deepcopy of the variable.""" - return IntVar( - copy.deepcopy(self.__model, memo), self.index, self.is_boolean, None - ) - - @property - def proto(self) -> cp_model_pb2.IntegerVariableProto: - """Returns the variable protobuf.""" - return self.__model.variables[self.index] - - @property - def model_proto(self) -> cp_model_pb2.CpModelProto: - """Returns the model protobuf.""" - return self.__model - - def is_equal_to(self, other: Any) -> bool: - """Returns true if self == other in the python sense.""" - if not isinstance(other, IntVar): - return False - return self.index == other.index - - def __str__(self) -> str: - if not self.proto.name: - if ( - len(self.proto.domain) == 2 - and self.proto.domain[0] == self.proto.domain[1] - ): - # Special case for constants. - return str(self.proto.domain[0]) - elif self.is_boolean: - return f"BooleanVar({self.__index})" - else: - return f"IntVar({self.__index})" - else: - return self.proto.name - - def __repr__(self) -> str: - return f"{self}({display_bounds(self.proto.domain)})" - - @property - def name(self) -> str: - if not self.proto or not self.proto.name: - return "" - return self.proto.name - - # Pre PEP8 compatibility. - # pylint: disable=invalid-name - def Name(self) -> str: - return self.name - - def Proto(self) -> cp_model_pb2.IntegerVariableProto: - return self.proto - - # pylint: enable=invalid-name +def expand_literals_generator_or_tuple( + args: Union[Tuple[LiteralT, ...], Iterable[LiteralT]], +) -> Union[Iterable[LiteralT], LiteralT]: + if hasattr(args, "__len__"): # Tuple + print("Tuple") + if len(args) != 1: + return args + if isinstance(args[0], (NumberTypes, cmh.Literal)): + return args + # Generator + print(f"Generator {args[0]} {type(args[0])}") + return args[0] class Constraint: @@ -338,24 +282,22 @@ class Constraint: model.add(x + 2 * y == 5).only_enforce_if(b.negated()) """ - def __init__( - self, - cp_model: "CpModel", - ) -> None: - self.__index: int = len(cp_model.proto.constraints) + def __init__(self, cp_model: "CpModel", index: Optional[int] = None) -> None: self.__cp_model: "CpModel" = cp_model - self.__constraint: cp_model_pb2.ConstraintProto = ( + if index is None: + self.__index: int = len(cp_model.proto.constraints) cp_model.proto.constraints.add() - ) + else: + self.__index: int = index @overload - def only_enforce_if(self, boolvar: Iterable[LiteralT]) -> "Constraint": ... + def only_enforce_if(self, literals: Iterable[LiteralT]) -> "Constraint": ... @overload - def only_enforce_if(self, *boolvar: LiteralT) -> "Constraint": ... + def only_enforce_if(self, *literals: LiteralT) -> "Constraint": ... - def only_enforce_if(self, *boolvar) -> "Constraint": - """Adds an enforcement literal to the constraint. + def only_enforce_if(self, *literals) -> "Constraint": + """Adds one or more enforcement literals to the constraint. This method adds one or more literals (that is, a boolean variable or its negation) as enforcement literals. The conjunction of all these literals @@ -366,43 +308,30 @@ class Constraint: BoolOr, BoolAnd, and linear constraints all support enforcement literals. Args: - *boolvar: One or more Boolean literals. + *literals: One or more Boolean literals. Returns: self. """ - for lit in expand_generator_or_tuple(boolvar): - if (cmn.is_boolean(lit) and lit) or ( - isinstance(lit, IntegralTypes) and lit == 1 - ): - # Always true. Do nothing. - pass - elif (cmn.is_boolean(lit) and not lit) or ( - isinstance(lit, IntegralTypes) and lit == 0 - ): - self.__constraint.enforcement_literal.append( - self.__cp_model.new_constant(0).index - ) - else: - self.__constraint.enforcement_literal.append( - cast(cmh.Literal, lit).index - ) + cmh.CpSatHelper.add_enforcement_literals( + self.__index, + self.__cp_model.expand_literals_to_index_list(literals), + self.__cp_model.proto, + ) return self def with_name(self, name: str) -> "Constraint": """Sets the name of the constraint.""" if name: - self.__constraint.name = name + cmh.CpSatHelper.set_ct_name(self.__index, name, self.__cp_model.proto) else: - self.__constraint.ClearField("name") + cmh.CpSatHelper.clear_ct_name(self.__index, self.__cp_model.proto) return self @property def name(self) -> str: """Returns the name of the constraint.""" - if not self.__constraint or not self.__constraint.name: - return "" - return self.__constraint.name + return cmh.CpSatHelper.ct_name(self.__index, self.__cp_model.proto) @property def index(self) -> int: @@ -410,9 +339,15 @@ class Constraint: return self.__index @property - def proto(self) -> cp_model_pb2.ConstraintProto: + def proto(self) -> cmb.ConstraintProto: """Returns the constraint protobuf.""" - return self.__constraint + return self.__cp_model.proto.constraints[self.__index] + + def __str__(self) -> str: + return ( + f"Constraint({self.__index}," + f" {self.__cp_model.proto.constraints[self.__index]})" + ) # Pre PEP8 compatibility. # pylint: disable=invalid-name @@ -425,55 +360,12 @@ class Constraint: def Index(self) -> int: return self.index - def Proto(self) -> cp_model_pb2.ConstraintProto: + def Proto(self) -> cmb.ConstraintProto: return self.proto # pylint: enable=invalid-name -class VariableList: - """Stores all integer variables of the model.""" - - def __init__(self) -> None: - self.__var_list: list[IntVar] = [] - - def append(self, var: IntVar) -> None: - assert var.index == len(self.__var_list) - self.__var_list.append(var) - - def get(self, index: int) -> IntVar: - if index < 0 or index >= len(self.__var_list): - raise ValueError("Index out of bounds.") - return self.__var_list[index] - - def rebuild_expr( - self, - proto: cp_model_pb2.LinearExpressionProto, - ) -> LinearExprT: - """Recreate a LinearExpr from a LinearExpressionProto.""" - num_elements = len(proto.vars) - if num_elements == 0: - return proto.offset - elif num_elements == 1: - var = self.get(proto.vars[0]) - return LinearExpr.affine( - var, proto.coeffs[0], proto.offset - ) # pytype: disable=bad-return-type - else: - variables = [] - for var_index in range(len(proto.vars)): - var = self.get(var_index) - variables.append(var) - if proto.offset != 0: - coeffs = [] - coeffs.extend(proto.coeffs) - coeffs.append(1) - variables.append(proto.offset) - return LinearExpr.weighted_sum(variables, coeffs) - else: - return LinearExpr.weighted_sum(variables, proto.coeffs) - - class IntervalVar: """Represents an Interval variable. @@ -497,18 +389,16 @@ class IntervalVar: def __init__( self, - model: cp_model_pb2.CpModelProto, - var_list: VariableList, - start: Union[cp_model_pb2.LinearExpressionProto, int], - size: Optional[cp_model_pb2.LinearExpressionProto], - end: Optional[cp_model_pb2.LinearExpressionProto], + model: cmb.CpModelProto, + start: Union[cmb.LinearExpressionProto, int], + size: Optional[cmb.LinearExpressionProto], + end: Optional[cmb.LinearExpressionProto], is_present_index: Optional[int], name: Optional[str], ) -> None: - self.__model: cp_model_pb2.CpModelProto = model - self.__var_list: VariableList = var_list + self.__model: cmb.CpModelProto = model self.__index: int - self.__ct: cp_model_pb2.ConstraintProto + self.__ct: cmb.ConstraintProto # As with the IntVar::__init__ method, we hack the __init__ method to # support two use cases: # case 1: called when creating a new interval variable. @@ -530,13 +420,13 @@ class IntervalVar: self.__ct = self.__model.constraints.add() if start is None: raise TypeError("start is not defined") - self.__ct.interval.start.CopyFrom(start) + self.__ct.interval.start.copy_from(start) if size is None: raise TypeError("size is not defined") - self.__ct.interval.size.CopyFrom(size) + self.__ct.interval.size.copy_from(size) if end is None: raise TypeError("end is not defined") - self.__ct.interval.end.CopyFrom(end) + self.__ct.interval.end.copy_from(end) if is_present_index is not None: self.__ct.enforcement_literal.append(is_present_index) if name: @@ -548,12 +438,12 @@ class IntervalVar: return self.__index @property - def proto(self) -> cp_model_pb2.ConstraintProto: + def proto(self) -> cmb.ConstraintProto: """Returns the interval protobuf.""" return self.__model.constraints[self.__index] @property - def model_proto(self) -> cp_model_pb2.CpModelProto: + def model_proto(self) -> cmb.CpModelProto: """Returns the model protobuf.""" return self.__model @@ -585,13 +475,19 @@ class IntervalVar: return self.proto.name def start_expr(self) -> LinearExprT: - return self.__var_list.rebuild_expr(self.proto.interval.start) + return rebuild_from_linear_expression_proto( + self.proto.interval.start, self.__model + ) def size_expr(self) -> LinearExprT: - return self.__var_list.rebuild_expr(self.proto.interval.size) + return rebuild_from_linear_expression_proto( + self.proto.interval.size, self.__model + ) def end_expr(self) -> LinearExprT: - return self.__var_list.rebuild_expr(self.proto.interval.end) + return rebuild_from_linear_expression_proto( + self.proto.interval.end, self.__model + ) # Pre PEP8 compatibility. # pylint: disable=invalid-name @@ -601,7 +497,7 @@ class IntervalVar: def Index(self) -> int: return self.index - def Proto(self) -> cp_model_pb2.ConstraintProto: + def Proto(self) -> cmb.ConstraintProto: return self.proto StartExpr = start_expr @@ -642,14 +538,13 @@ class CpModel: Methods beginning with: - * ```New``` create integer, boolean, or interval variables. - * ```add``` create new constraints and add them to the model. + * ```new_``` create integer, boolean, or interval variables. + * ```add_``` create new constraints and add them to the model. """ def __init__(self) -> None: - self.__model: cp_model_pb2.CpModelProto = cp_model_pb2.CpModelProto() + self.__model: cmb.CpModelProto = cmb.CpModelProto() self.__constant_map: Dict[IntegralT, int] = {} - self.__var_list: VariableList = VariableList() # Naming. @property @@ -665,21 +560,6 @@ class CpModel: self.__model.name = name # Integer variable. - - def _append_int_var(self, var: IntVar) -> IntVar: - """Appends an integer variable to the list of variables.""" - self.__var_list.append(var) - return var - - def _get_int_var(self, index: int) -> IntVar: - return self.__var_list.get(index) - - def rebuild_from_linear_expression_proto( - self, - proto: cp_model_pb2.LinearExpressionProto, - ) -> LinearExpr: - return self.__var_list.rebuild_expr(proto) - def new_int_var(self, lb: IntegralT, ub: IntegralT, name: str) -> IntVar: """Create an integer variable with domain [lb, ub]. @@ -695,14 +575,10 @@ class CpModel: Returns: a variable whose domain is [lb, ub]. """ - domain_is_boolean = lb >= 0 and ub <= 1 - return self._append_int_var( - IntVar( - self.__model, - sorted_interval_list.Domain(lb, ub), - domain_is_boolean, - name, - ) + return ( + IntVar(self.__model) + .with_name(name) + .with_domain(sorted_interval_list.Domain(lb, ub)) ) def new_int_var_from_domain( @@ -721,21 +597,20 @@ class CpModel: Returns: a variable whose domain is the given domain. """ - domain_is_boolean = domain.min() >= 0 and domain.max() <= 1 - return self._append_int_var( - IntVar(self.__model, domain, domain_is_boolean, name) - ) + return IntVar(self.__model).with_name(name).with_domain(domain) def new_bool_var(self, name: str) -> IntVar: """Creates a 0-1 variable with the given name.""" - return self._append_int_var( - IntVar(self.__model, sorted_interval_list.Domain(0, 1), True, name) + return ( + IntVar(self.__model) + .with_name(name) + .with_domain(sorted_interval_list.Domain(0, 1)) ) def new_constant(self, value: IntegralT) -> IntVar: """Declares a constant integer.""" index: int = self.get_or_make_index_from_constant(value) - return self._get_int_var(index) + return IntVar(self.__model, index) def new_int_var_series( self, @@ -790,15 +665,10 @@ class CpModel: index=index, data=[ # pylint: disable=g-complex-comprehension - self._append_int_var( - IntVar( - model=self.__model, - name=f"{name}[{i}]", - domain=sorted_interval_list.Domain( - lower_bounds[i], upper_bounds[i] - ), - is_boolean=lower_bounds[i] >= 0 and upper_bounds[i] <= 1, - ) + IntVar(self.__model) + .with_name(f"{name}[{i}]") + .with_domain( + sorted_interval_list.Domain(lower_bounds[i], upper_bounds[i]) ) for i in index ], @@ -830,14 +700,9 @@ class CpModel: index=index, data=[ # pylint: disable=g-complex-comprehension - self._append_int_var( - IntVar( - model=self.__model, - name=f"{name}[{i}]", - domain=sorted_interval_list.Domain(0, 1), - is_boolean=True, - ) - ) + IntVar(self.__model) + .with_name(f"{name}[{i}]") + .with_domain(sorted_interval_list.Domain(0, 1)) for i in index ], ) @@ -889,21 +754,15 @@ class CpModel: TypeError: If the `ct` is not a `BoundedLinearExpression` or a Boolean. """ if isinstance(ct, BoundedLinearExpression): - result = Constraint(self) - model_ct = self.__model.constraints[result.index] - for var in ct.vars: - model_ct.linear.vars.append(var.index) - model_ct.linear.coeffs.extend(ct.coeffs) - model_ct.linear.domain.extend( - [ - cmn.capped_subtraction(x, ct.offset) - for x in ct.bounds.flattened_intervals() - ] + return Constraint( + self, + cmh.CpSatHelper.add_bounded_linear_expression_to_model( + ct, self.__model + ), ) - return result - if ct and cmn.is_boolean(ct): + if ct and arg_is_boolean(ct): return self.add_bool_or([True]) - if not ct and cmn.is_boolean(ct): + if not ct and arg_is_boolean(ct): return self.add_bool_or([]) # Evaluate to false. raise TypeError(f"not supported: CpModel.add({type(ct).__name__!r})") @@ -928,7 +787,7 @@ class CpModel: """ ct = Constraint(self) model_ct = self.__model.constraints[ct.index] - expanded = expand_generator_or_tuple(expressions) + expanded = expand_exprs_generator_or_tuple(expressions) model_ct.all_diff.exprs.extend( self.parse_linear_expression(x) for x in expanded ) @@ -962,11 +821,11 @@ class CpModel: ct = Constraint(self) model_ct = self.__model.constraints[ct.index] - model_ct.element.linear_index.CopyFrom(self.parse_linear_expression(index)) + model_ct.element.linear_index.copy_from(self.parse_linear_expression(index)) model_ct.element.exprs.extend( [self.parse_linear_expression(e) for e in expressions] ) - model_ct.element.linear_target.CopyFrom(self.parse_linear_expression(target)) + model_ct.element.linear_target.copy_from(self.parse_linear_expression(target)) return ct def add_circuit(self, arcs: Sequence[ArcT]) -> Constraint: @@ -1417,15 +1276,9 @@ class CpModel: def add_bool_or(self, *literals): """Adds `Or(literals) == true`: sum(literals) >= 1.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_or.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_bool_or(lits, self.__model) + return Constraint(self, index) @overload def add_at_least_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1437,23 +1290,19 @@ class CpModel: """Same as `add_bool_or`: `sum(literals) >= 1`.""" return self.add_bool_or(*literals) - @overload - def add_at_most_one(self, literals: Iterable[LiteralT]) -> Constraint: ... + # @overload + # def add_at_most_one(self, literals: Iterable[LiteralT]) -> Constraint: + # ... - @overload - def add_at_most_one(self, *literals: LiteralT) -> Constraint: ... + # @overload + # def add_at_most_one(self, *literals: LiteralT) -> Constraint: + # ... - def add_at_most_one(self, *literals): + def add_at_most_one(self, *literals) -> Constraint: """Adds `AtMostOne(literals)`: `sum(literals) <= 1`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.at_most_one.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_at_most_one(lits, self.__model) + return Constraint(self, index) @overload def add_exactly_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1463,15 +1312,9 @@ class CpModel: def add_exactly_one(self, *literals): """Adds `ExactlyOne(literals)`: `sum(literals) == 1`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.exactly_one.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_exactly_one(lits, self.__model) + return Constraint(self, index) @overload def add_bool_and(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1481,15 +1324,9 @@ class CpModel: def add_bool_and(self, *literals): """Adds `And(literals) == true`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_and.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_bool_and(lits, self.__model) + return Constraint(self, index) @overload def add_bool_xor(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1509,15 +1346,9 @@ class CpModel: Returns: An `Constraint` object. """ - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_xor.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_bool_xor(lits, self.__model) + return Constraint(self, index) def add_min_equality( self, target: LinearExprT, exprs: Iterable[LinearExprT] @@ -1528,7 +1359,7 @@ class CpModel: model_ct.lin_max.exprs.extend( [self.parse_linear_expression(x, True) for x in exprs] ) - model_ct.lin_max.target.CopyFrom(self.parse_linear_expression(target, True)) + model_ct.lin_max.target.copy_from(self.parse_linear_expression(target, True)) return ct def add_max_equality( @@ -1538,7 +1369,7 @@ class CpModel: ct = Constraint(self) model_ct = self.__model.constraints[ct.index] model_ct.lin_max.exprs.extend([self.parse_linear_expression(x) for x in exprs]) - model_ct.lin_max.target.CopyFrom(self.parse_linear_expression(target)) + model_ct.lin_max.target.copy_from(self.parse_linear_expression(target)) return ct def add_division_equality( @@ -1549,7 +1380,7 @@ class CpModel: model_ct = self.__model.constraints[ct.index] model_ct.int_div.exprs.append(self.parse_linear_expression(num)) model_ct.int_div.exprs.append(self.parse_linear_expression(denom)) - model_ct.int_div.target.CopyFrom(self.parse_linear_expression(target)) + model_ct.int_div.target.copy_from(self.parse_linear_expression(target)) return ct def add_abs_equality(self, target: LinearExprT, expr: LinearExprT) -> Constraint: @@ -1558,7 +1389,7 @@ class CpModel: model_ct = self.__model.constraints[ct.index] model_ct.lin_max.exprs.append(self.parse_linear_expression(expr)) model_ct.lin_max.exprs.append(self.parse_linear_expression(expr, True)) - model_ct.lin_max.target.CopyFrom(self.parse_linear_expression(target)) + model_ct.lin_max.target.copy_from(self.parse_linear_expression(target)) return ct def add_modulo_equality( @@ -1587,7 +1418,7 @@ class CpModel: model_ct = self.__model.constraints[ct.index] model_ct.int_mod.exprs.append(self.parse_linear_expression(expr)) model_ct.int_mod.exprs.append(self.parse_linear_expression(mod)) - model_ct.int_mod.target.CopyFrom(self.parse_linear_expression(target)) + model_ct.int_mod.target.copy_from(self.parse_linear_expression(target)) return ct def add_multiplication_equality( @@ -1601,10 +1432,10 @@ class CpModel: model_ct.int_prod.exprs.extend( [ self.parse_linear_expression(expr) - for expr in expand_generator_or_tuple(expressions) + for expr in expand_exprs_generator_or_tuple(expressions) ] ) - model_ct.int_prod.target.CopyFrom(self.parse_linear_expression(target)) + model_ct.int_prod.target.copy_from(self.parse_linear_expression(target)) return ct # Scheduling support @@ -1646,7 +1477,6 @@ class CpModel: ) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -1731,7 +1561,6 @@ class CpModel: ) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -1833,7 +1662,6 @@ class CpModel: ) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -1932,7 +1760,6 @@ class CpModel: is_present_index = self.get_or_make_boolean_index(is_present) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -2074,32 +1901,32 @@ class CpModel: ) for d in demands: model_ct.cumulative.demands.append(self.parse_linear_expression(d)) - model_ct.cumulative.capacity.CopyFrom(self.parse_linear_expression(capacity)) + model_ct.cumulative.capacity.copy_from(self.parse_linear_expression(capacity)) return cumulative # Support for model cloning. def clone(self) -> "CpModel": """Reset the model, and creates a new one from a CpModelProto instance.""" clone = CpModel() - clone.proto.CopyFrom(self.proto) - clone.rebuild_var_and_constant_map() + clone.proto.copy_from(self.proto) + clone.rebuild_constant_map() return clone - def rebuild_var_and_constant_map(self): + def rebuild_constant_map(self): """Internal method used during model cloning.""" for i, var in enumerate(self.__model.variables): if len(var.domain) == 2 and var.domain[0] == var.domain[1]: self.__constant_map[var.domain[0]] = i - is_boolean = ( - len(var.domain) == 2 and var.domain[0] >= 0 and var.domain[1] <= 1 - ) - self.__var_list.append(IntVar(self.__model, i, is_boolean, None)) def get_bool_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created Boolean variable from its index.""" - result = self._get_int_var(index) - if not result.is_boolean: + if index < 0 or index >= len(self.__model.variables): raise ValueError( + f"get_bool_var_from_proto_index: out of bound index {index}" + ) + result = IntVar(self.__model, index) + if not result.is_boolean: + raise TypeError( f"get_bool_var_from_proto_index: index {index} does not reference a" " boolean variable" ) @@ -2107,7 +1934,11 @@ class CpModel: def get_int_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created integer variable from its index.""" - return self._get_int_var(index) + if index < 0 or index >= len(self.__model.variables): + raise ValueError( + f"get_int_var_from_proto_index: out of bound index {index}" + ) + return IntVar(self.__model, index) def get_interval_var_from_proto_index(self, index: int) -> IntervalVar: """Returns an already created interval variable from its index.""" @@ -2116,13 +1947,13 @@ class CpModel: f"get_interval_var_from_proto_index: out of bound index {index}" ) ct = self.__model.constraints[index] - if not ct.HasField("interval"): + if not ct.has_interval(): raise ValueError( f"get_interval_var_from_proto_index: index {index} does not" " reference an" + " interval variable" ) - return IntervalVar(self.__model, self.__var_list, index, None, None, None, None) + return IntervalVar(self.__model, index, None, None, None, None) # Helpers. @@ -2130,7 +1961,7 @@ class CpModel: return str(self.__model) @property - def proto(self) -> cp_model_pb2.CpModelProto: + def proto(self) -> cmb.CpModelProto: """Returns the underlying CpModelProto.""" return self.__model @@ -2160,9 +1991,11 @@ class CpModel: return self.get_or_make_index_from_constant(1) if arg == ~int(True): return self.get_or_make_index_from_constant(0) - arg = cmn.assert_is_zero_or_one(arg) - return self.get_or_make_index_from_constant(arg) - if cmn.is_boolean(arg): + arg_as_int: int = int(arg) + if arg_as_int < 0 or arg_as_int > 1: + raise TypeError(f"Not a boolean: {arg}") + return self.get_or_make_index_from_constant(arg_as_int) + if arg_is_boolean(arg): return self.get_or_make_index_from_constant(int(arg)) raise TypeError( "not supported:" f" model.get_or_make_boolean_index({type(arg).__name__!r})" @@ -2184,11 +2017,9 @@ class CpModel: def parse_linear_expression( self, linear_expr: LinearExprT, negate: bool = False - ) -> cp_model_pb2.LinearExpressionProto: + ) -> cmb.LinearExpressionProto: """Returns a LinearExpressionProto built from a LinearExpr instance.""" - result: cp_model_pb2.LinearExpressionProto = ( - cp_model_pb2.LinearExpressionProto() - ) + result: cmb.LinearExpressionProto = cmb.LinearExpressionProto() mult = -1 if negate else 1 if isinstance(linear_expr, IntegralTypes): result.offset = int(linear_expr) * mult @@ -2244,19 +2075,19 @@ class CpModel: self._set_objective(obj, minimize=False) def has_objective(self) -> bool: - return self.__model.HasField("objective") or self.__model.HasField( - "floating_point_objective" + return ( + self.__model.has_objective() or self.__model.has_floating_point_objective() ) def clear_objective(self): - self.__model.ClearField("objective") - self.__model.ClearField("floating_point_objective") + self.__model.clear_objective() + self.__model.clear_floating_point_objective() def add_decision_strategy( self, variables: Sequence[IntVar], - var_strategy: cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy, - domain_strategy: cp_model_pb2.DecisionStrategyProto.DomainReductionStrategy, + var_strategy: cmb.DecisionStrategyProto.VariableSelectionStrategy, + domain_strategy: cmb.DecisionStrategyProto.DomainReductionStrategy, ) -> None: """Adds a search strategy to the model. @@ -2269,9 +2100,7 @@ class CpModel: solve() will fail. """ - strategy: cp_model_pb2.DecisionStrategyProto = ( - self.__model.search_strategy.add() - ) + strategy: cmb.DecisionStrategyProto = self.__model.search_strategy.add() for v in variables: expr = strategy.exprs.add() if v.index >= 0: @@ -2308,11 +2137,11 @@ class CpModel: def remove_all_names(self) -> None: """Removes all names from the model.""" - self.__model.ClearField("name") + self.__model.clear_name() for v in self.__model.variables: - v.ClearField("name") + v.clear_name() for c in self.__model.constraints: - c.ClearField("name") + c.clear_name() @overload def add_hint(self, var: IntVar, value: int) -> None: ... @@ -2331,7 +2160,7 @@ class CpModel: def clear_hints(self): """Removes any solution hint from the model.""" - self.__model.ClearField("solution_hint") + self.__model.clear_solution_hint() def add_assumption(self, lit: LiteralT) -> None: """Adds the literal to the model as assumptions.""" @@ -2344,7 +2173,7 @@ class CpModel: def clear_assumptions(self) -> None: """Removes all assumptions from the model.""" - self.__model.ClearField("assumptions") + self.__model.clear_assumptions() # Helpers. def assert_is_boolean_variable(self, x: LiteralT) -> None: @@ -2359,6 +2188,27 @@ class CpModel: f"TypeError: {type(x).__name__!r} is not a boolean variable" ) + def expand_literals_enerator_or_tuple( + self, args: Union[Tuple[LiteralT, ...], Iterable[LiteralT]] + ): + if hasattr(args, "__len__"): # Tuple + if len(args) != 1: + return args + if isinstance(args[0], (NumberTypes, cmh.Literal)): + return args + # Generator + return args[0] + + def expand_literals_to_index_list( + self, + literals: Union[Tuple[LiteralT, ...], Iterable[LiteralT]], + ) -> list[int]: + """Expands a tuple or generator of literals to a list of indices.""" + return [ + self.get_or_make_boolean_index(lit) + for lit in self.expand_literals_enerator_or_tuple(literals) + ] + # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -2368,7 +2218,7 @@ class CpModel: def SetName(self, name: str) -> None: self.name = name - def Proto(self) -> cp_model_pb2.CpModelProto: + def Proto(self) -> cmb.CpModelProto: return self.proto NewIntVar = new_int_var @@ -2434,19 +2284,9 @@ class CpModel: # pylint: enable=invalid-name -@overload -def expand_generator_or_tuple( - args: Union[Tuple[LiteralT, ...], Iterable[LiteralT]], -) -> Union[Iterable[LiteralT], LiteralT]: ... - - -@overload -def expand_generator_or_tuple( +def expand_exprs_generator_or_tuple( args: Union[Tuple[LinearExprT, ...], Iterable[LinearExprT]], -) -> Union[Iterable[LinearExprT], LinearExprT]: ... - - -def expand_generator_or_tuple(args): +) -> Union[Iterable[LinearExprT], LinearExprT]: if hasattr(args, "__len__"): # Tuple if len(args) != 1: return args @@ -2469,9 +2309,7 @@ class CpSolver: def __init__(self) -> None: self.__response_wrapper: Optional[cmh.ResponseWrapper] = None - self.parameters: sat_parameters_pb2.SatParameters = ( - sat_parameters_pb2.SatParameters() - ) + self.parameters: spb.SatParameters = spb.SatParameters() self.log_callback: Optional[Callable[[str], None]] = None self.best_bound_callback: Optional[Callable[[float], None]] = None self.__solve_wrapper: Optional[cmh.SolveWrapper] = None @@ -2481,7 +2319,7 @@ class CpSolver: self, model: CpModel, solution_callback: Optional["CpSolverSolutionCallback"] = None, - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmb.CpSolverStatus: """Solves a problem and passes each solution to the callback if not null.""" with self.__lock: self.__solve_wrapper = cmh.SolveWrapper() @@ -2628,6 +2466,21 @@ class CpSolver: """Returns the number of search branches explored by the solver.""" return self._checked_response.num_branches() + @property + def num_boolean_propagations(self) -> int: + """Returns the number of Boolean propagations done by the solver.""" + return self._checked_response.num_boolean_propagations() + + @property + def num_integer_propagations(self) -> int: + """Returns the number of integer propagations done by the solver.""" + return self._checked_response.num_integer_propagations() + + @property + def deterministic_time(self) -> float: + """Returns the deterministic time in seconds since the creation of the solver.""" + return self._checked_response.deterministic_time() + @property def wall_time(self) -> float: """Returns the wall time in seconds since the creation of the solver.""" @@ -2639,7 +2492,7 @@ class CpSolver: return self._checked_response.user_time() @property - def response_proto(self) -> cp_model_pb2.CpSolverResponse: + def response_proto(self) -> cmb.CpSolverResponse: """Returns the response object.""" return self._checked_response.response() @@ -2655,7 +2508,7 @@ class CpSolver: """Returns the name of the status returned by solve().""" if status is None: status = self._checked_response.status() - return cp_model_pb2.CpSolverStatus.Name(status) + return status.name def solution_info(self) -> str: """Returns some information on the solve process. @@ -2699,7 +2552,7 @@ class CpSolver: def ObjectiveValue(self) -> float: return self.objective_value - def ResponseProto(self) -> cp_model_pb2.CpSolverResponse: + def ResponseProto(self) -> cmb.CpSolverResponse: return self.response_proto def ResponseStats(self) -> str: @@ -2709,7 +2562,7 @@ class CpSolver: self, model: CpModel, solution_callback: Optional["CpSolverSolutionCallback"] = None, - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmb.CpSolverStatus: return self.solve(model, solution_callback) def SolutionInfo(self) -> str: @@ -2738,7 +2591,7 @@ class CpSolver: def SolveWithSolutionCallback( self, model: CpModel, callback: "CpSolverSolutionCallback" - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmb.CpSolverStatus: """DEPRECATED Use solve() with the callback argument.""" warnings.warn( "solve_with_solution_callback is deprecated; use solve() with" @@ -2749,7 +2602,7 @@ class CpSolver: def SearchForAllSolutions( self, model: CpModel, callback: "CpSolverSolutionCallback" - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmb.CpSolverStatus: """DEPRECATED Use solve() with the right parameter. Search for all solutions of a satisfiability problem. @@ -2783,7 +2636,7 @@ class CpSolver: enumerate_all = self.parameters.enumerate_all_solutions self.parameters.enumerate_all_solutions = True - status: cp_model_pb2.CpSolverStatus = self.solve(model, callback) + status: cmb.CpSolverStatus = self.solve(model, callback) # Restore parameter. self.parameters.enumerate_all_solutions = enumerate_all @@ -2944,7 +2797,7 @@ class CpSolverSolutionCallback(cmh.SolutionCallback): return self.UserTime() @property - def response_proto(self) -> cp_model_pb2.CpSolverResponse: + def response_proto(self) -> cmb.CpSolverResponse: """Returns the response object.""" if not self.has_response(): raise RuntimeError("solve() has not been called.") diff --git a/ortools/sat/python/cp_model_builder_test.py b/ortools/sat/python/cp_model_builder_test.py new file mode 100644 index 0000000000..05293b37ca --- /dev/null +++ b/ortools/sat/python/cp_model_builder_test.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from ortools.sat import cp_model_pb2 +from ortools.sat.python import cp_model_builder + + +class CpModelBuilderTest(absltest.TestCase): + + def test_basic(self): + model_proto = cp_model_builder.CpModelProto() + + # Singular message. + objective = model_proto.objective + + # Singular int. + self.assertEqual(objective.offset, 0) + objective.offset = 123 + self.assertEqual(objective.offset, 123) + + # Set a message. + new_obj = cp_model_builder.CpObjectiveProto() + new_obj.offset = 456 + model_proto.objective = new_obj + self.assertEqual(objective.offset, 456) + + # Large int. + objective.offset = 500000000000 + self.assertEqual(objective.offset, 500000000000) + + # Repeated message. + my_var = model_proto.variables.add() + + # Singular string. + self.assertEqual(my_var.name, "") + my_var.name = "my_var" + self.assertEqual(my_var.name, "my_var") + my_var.domain.extend([0, 1]) + domain = list(my_var.domain) + self.assertLen(domain, 2) + self.assertEqual(domain[0], 0) + self.assertEqual(domain[1], 1) + + # Repeated int. + objective.vars.append(0) + self.assertLen(objective.vars, 1) + self.assertEqual(objective.vars[0], 0) + objective.vars[0] = 42 + self.assertEqual(objective.vars[0], 42) + + # Singular enum + search_strategy = model_proto.search_strategy.add() + self.assertEqual( + search_strategy.variable_selection_strategy, + cp_model_builder.DecisionStrategyProto.CHOOSE_FIRST, + ) + search_strategy.variable_selection_strategy = ( + cp_model_builder.DecisionStrategyProto.CHOOSE_LOWEST_MIN + ) + self.assertEqual( + search_strategy.variable_selection_strategy, + cp_model_pb2.DecisionStrategyProto.CHOOSE_LOWEST_MIN, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index e077cc3141..3fe72f69cc 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -30,6 +30,7 @@ #include "ortools/sat/python/linear_expr.h" #include "ortools/sat/python/linear_expr_doc.h" #include "ortools/sat/swig_helper.h" +#include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" #include "pybind11/attr.h" #include "pybind11/cast.h" @@ -39,7 +40,6 @@ #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" -#include "pybind11_protobuf/native_proto_caster.h" namespace py = pybind11; @@ -83,28 +83,6 @@ class PySolutionCallback : public SolutionCallback { } }; -// A trampoline class to override the __str__ and __repr__ methods. -class PyBaseIntVar : public BaseIntVar { - public: - using BaseIntVar::BaseIntVar; /* Inherit constructors */ - - std::string ToString() const override { - PYBIND11_OVERRIDE_PURE_NAME(std::string, // Return type (ret_type) - BaseIntVar, // Parent class (cname) - "__str__", // Name of method in Python (name) - ToString, // Name of function in C++ (fn) - ); - } - - std::string DebugString() const override { - PYBIND11_OVERRIDE_PURE_NAME(std::string, // Return type (ret_type) - BaseIntVar, // Parent class (cname) - "__repr__", // Name of method in Python (name) - DebugString, // Name of function in C++ (fn) - ); - } -}; - // A class to wrap a C++ CpSolverResponse in a Python object, avoid the proto // conversion back to python. class ResponseWrapper { @@ -434,8 +412,96 @@ std::shared_ptr WeightedSumArguments(py::sequence expressions, } } +int AddBoundedLinearExpressionToModel( + BoundedLinearExpression* ble, std::shared_ptr model_proto) { + const int index = model_proto->constraints_size(); + ConstraintProto* ct = model_proto->add_constraints(); + for (const auto& var : ble->vars()) { + ct->mutable_linear()->add_vars(var->index()); + } + for (const int64_t coeff : ble->coeffs()) { + ct->mutable_linear()->add_coeffs(coeff); + } + const int64_t offset = ble->offset(); + const Domain& bounds = ble->bounds(); + for (const int64_t bound : bounds.FlattenedIntervals()) { + if (bound == std::numeric_limits::min() || + bound == std::numeric_limits::max()) { + ct->mutable_linear()->add_domain(bound); + } else { + ct->mutable_linear()->add_domain(CapSub(bound, offset)); + } + } + return index; +} + +int AddBoolOr(const std::vector& literals, + std::shared_ptr model_proto) { + const int index = model_proto->constraints_size(); + ConstraintProto* ct = model_proto->add_constraints(); + ct->mutable_bool_or()->mutable_literals()->Add(literals.begin(), + literals.end()); + return index; +} + +int AddBoolAnd(const std::vector& literals, + std::shared_ptr model_proto) { + const int index = model_proto->constraints_size(); + ConstraintProto* ct = model_proto->add_constraints(); + ct->mutable_bool_and()->mutable_literals()->Add(literals.begin(), + literals.end()); + return index; +} + +int AddBoolXOr(const std::vector& literals, + std::shared_ptr model_proto) { + const int index = model_proto->constraints_size(); + ConstraintProto* ct = model_proto->add_constraints(); + ct->mutable_bool_xor()->mutable_literals()->Add(literals.begin(), + literals.end()); + return index; +} + +int AddAtMostOne(const std::vector& literals, + std::shared_ptr model_proto) { + const int index = model_proto->constraints_size(); + ConstraintProto* ct = model_proto->add_constraints(); + ct->mutable_at_most_one()->mutable_literals()->Add(literals.begin(), + literals.end()); + return index; +} + +int AddExactlyOne(const std::vector& literals, + std::shared_ptr model_proto) { + const int index = model_proto->constraints_size(); + ConstraintProto* ct = model_proto->add_constraints(); + ct->mutable_exactly_one()->mutable_literals()->Add(literals.begin(), + literals.end()); + return index; +} + +void AddEnforcementLiterals(int index, const std::vector& literals, + std::shared_ptr model_proto) { + ConstraintProto* ct = model_proto->mutable_constraints(index); + ct->mutable_enforcement_literal()->Add(literals.begin(), literals.end()); +} + +void SetCtName(int index, const std::string& name, + std::shared_ptr model_proto) { + model_proto->mutable_constraints(index)->set_name(name); +} + +std::string GetCtName(int index, std::shared_ptr model_proto) { + return model_proto->constraints(index).name(); +} + +void ClearCtName(int index, std::shared_ptr model_proto) { + model_proto->mutable_constraints(index)->clear_name(); +} + PYBIND11_MODULE(cp_model_helper, m) { - pybind11_protobuf::ImportNativeProtoCasters(); + py::module::import("ortools.sat.python.cp_model_builder"); + py::module::import("ortools.sat.python.sat_parameters_builder"); py::module::import("ortools.util.python.sorted_interval_list"); // We keep the CamelCase name for the SolutionCallback class to be @@ -577,28 +643,35 @@ PYBIND11_MODULE(cp_model_helper, m) { solve_wrapper->AddBestBoundCallback(safe_best_bound_callback); }, py::arg("best_bound_callback").none(false)) - .def("set_parameters", &SolveWrapper::SetParameters, - py::arg("parameters")) - .def("solve", - [](ExtSolveWrapper* solve_wrapper, - const CpModelProto& model_proto) -> CpSolverResponse { - const auto result = [&]() -> CpSolverResponse { - ::py::gil_scoped_release release; - return solve_wrapper->Solve(model_proto); - }(); - if (solve_wrapper->local_error_already_set_.has_value()) { - solve_wrapper->local_error_already_set_->restore(); - solve_wrapper->local_error_already_set_.reset(); - throw py::error_already_set(); - } - return result; - }) + .def( + "set_parameters", + [](ExtSolveWrapper* solve_wrapper, + std::shared_ptr parameters) { + solve_wrapper->SetParameters(*parameters); + }, + py::arg("parameters").none(false)) + .def( + "solve", + [](ExtSolveWrapper* solve_wrapper, + std::shared_ptr model_proto) -> CpSolverResponse { + const auto result = [=]() -> CpSolverResponse { + ::py::gil_scoped_release release; + return solve_wrapper->Solve(*model_proto); + }(); + if (solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_->restore(); + solve_wrapper->local_error_already_set_.reset(); + throw py::error_already_set(); + } + return result; + }, + py::arg("model_proto").none(false)) .def("solve_and_return_response_wrapper", [](ExtSolveWrapper* solve_wrapper, - const CpModelProto& model_proto) -> ResponseWrapper { - const auto result = [&]() -> ResponseWrapper { + std::shared_ptr model_proto) -> ResponseWrapper { + const auto result = [=]() -> ResponseWrapper { ::py::gil_scoped_release release; - return ResponseWrapper(solve_wrapper->Solve(model_proto)); + return ResponseWrapper(solve_wrapper->Solve(*model_proto)); }(); if (solve_wrapper->local_error_already_set_.has_value()) { solve_wrapper->local_error_already_set_->restore(); @@ -619,7 +692,29 @@ PYBIND11_MODULE(cp_model_helper, m) { .def_static("variable_domain", &CpSatHelper::VariableDomain, py::arg("variable_proto")) .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile, - py::arg("model_proto"), py::arg("filename")); + py::arg("model_proto"), py::arg("filename")) + .def_static("set_ct_name", &SetCtName, py::arg("index"), py::arg("name"), + py::arg("model_proto")) + .def_static("ct_name", &GetCtName, py::arg("index"), + py::arg("model_proto")) + .def_static("clear_ct_name", &ClearCtName, py::arg("index"), + py::arg("model_proto")) + .def_static("add_bool_or", &AddBoolOr, py::arg("literals"), + py::arg("model_proto").none(false)) + .def_static("add_bool_and", &AddBoolAnd, py::arg("literals"), + py::arg("model_proto").none(false)) + .def_static("add_bool_xor", &AddBoolXOr, py::arg("literals"), + py::arg("model_proto").none(false)) + .def_static("add_at_most_one", &AddAtMostOne, py::arg("literals"), + py::arg("model_proto").none(false)) + .def_static("add_exactly_one", &AddExactlyOne, py::arg("literals"), + py::arg("model_proto").none(false)) + .def_static("add_enforcement_literals", &AddEnforcementLiterals, + py::arg("index"), py::arg("literals"), + py::arg("model_proto").none(false)) + .def_static("add_bounded_linear_expression_to_model", + &AddBoundedLinearExpressionToModel, py::arg("ble"), + py::arg("model_proto")); py::class_>( m, "LinearExpr", DOC(operations_research, sat, python, LinearExpr)) @@ -894,31 +989,27 @@ PYBIND11_MODULE(cp_model_helper, m) { py::init>, int64_t, double>()) .def( "__add__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Add)) .def( "__add__", - [](py::object self, int64_t cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); }, DOC(operations_research, sat, python, LinearExpr, AddInt)) .def( "__add__", - [](py::object self, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddFloatInPlace(cst) : expr->AddFloat(cst); }, @@ -926,10 +1017,9 @@ PYBIND11_MODULE(cp_model_helper, m) { DOC(operations_research, sat, python, LinearExpr, AddFloat)) .def( "__radd__", - [](py::object self, int64_t cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); }, @@ -937,10 +1027,9 @@ PYBIND11_MODULE(cp_model_helper, m) { DOC(operations_research, sat, python, LinearExpr, AddInt)) .def( "__radd__", - [](py::object self, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddFloatInPlace(cst) : expr->AddFloat(cst); }, @@ -971,11 +1060,9 @@ PYBIND11_MODULE(cp_model_helper, m) { DOC(operations_research, sat, python, LinearExpr, AddFloat)) .def( "__sub__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddInPlace(other->Neg()) : expr->Sub(other); }, @@ -983,10 +1070,9 @@ PYBIND11_MODULE(cp_model_helper, m) { DOC(operations_research, sat, python, LinearExpr, Sub)) .def( "__sub__", - [](py::object self, int64_t cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddIntInPlace(-cst) : expr->SubInt(cst); }, @@ -994,10 +1080,9 @@ PYBIND11_MODULE(cp_model_helper, m) { DOC(operations_research, sat, python, LinearExpr, SubInt)) .def( "__sub__", - [](py::object self, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + const int num_uses = Py_REFCNT(py::cast(expr).ptr()); return (num_uses == 4) ? expr->AddFloatInPlace(-cst) : expr->SubFloat(cst); }, @@ -1067,52 +1152,99 @@ PYBIND11_MODULE(cp_model_helper, m) { .def("Not", &Literal::negated) .def("Index", &Literal::index); - // Memory management: - // - The BaseIntVar owns the NotBooleanVariable and keeps a shared_ptr to it. - // - The NotBooleanVariable is created on demand, and is deleted when the base - // variable is deleted. It holds a weak_ptr to the base variable. - py::class_, Literal>( - m, "BaseIntVar", DOC(operations_research, sat, python, BaseIntVar)) - .def(py::init()) // Integer variable. - .def(py::init()) // Potential Boolean variable. + // IntVar and NotBooleanVariable both hold a shared_ptr to the model_proto. + py::class_, Literal>( + m, "IntVar", DOC(operations_research, sat, python, IntVar)) + .def(py::init, int>()) + .def(py::init>()) // new variable. .def_property_readonly( - "index", &BaseIntVar::index, - DOC(operations_research, sat, python, BaseIntVar, index)) + "proto", &IntVar::proto, py::return_value_policy::reference, + py::keep_alive<1, 0>() + // DOC(operations_research, sat, python, IntVar, proto) + ) .def_property_readonly( - "is_boolean", &BaseIntVar::is_boolean, - DOC(operations_research, sat, python, BaseIntVar, is_boolean)) - .def("__str__", &BaseIntVar::ToString) - .def("__repr__", &BaseIntVar::DebugString) + "model_proto", &IntVar::model_proto + // DOC(operations_research, sat, python, IntVar, model_proto) + ) + .def_property_readonly( + "index", &IntVar::index, py::return_value_policy::reference, + DOC(operations_research, sat, python, IntVar, index)) + .def_property_readonly( + "is_boolean", &IntVar::is_boolean, + DOC(operations_research, sat, python, IntVar, is_boolean)) + .def_property( + "name", &IntVar::name, &IntVar::SetName //, py::arg("name") + // DOC(operations_research, + // sat, python, IntVar, name) + ) + .def( + "with_name", + [](std::shared_ptr self, const std::string& name) { + self->SetName(name); + return self; + }, + py::arg("name")) + .def_property( + "domain", &IntVar::domain, &IntVar::SetDomain //, py::arg("domain") + // DOC(operations_research, sat, python, IntVar, domain) + ) + .def( + "with_domain", + [](std::shared_ptr self, const Domain& domain) { + self->SetDomain(domain); + return self; + }, + py::arg("domain")) + .def("__str__", &IntVar::ToString) + .def("__repr__", &IntVar::DebugString) .def( "negated", - [](std::shared_ptr self) { + [](std::shared_ptr self) { if (!self->is_boolean()) { ThrowError(PyExc_TypeError, "negated() is only supported for Boolean variables."); } return self->negated(); }, - DOC(operations_research, sat, python, BaseIntVar, negated)) + DOC(operations_research, sat, python, IntVar, negated)) .def( "__invert__", - [](std::shared_ptr self) { + [](std::shared_ptr self) { if (!self->is_boolean()) { ThrowError(PyExc_TypeError, "negated() is only supported for Boolean variables."); } return self->negated(); }, - DOC(operations_research, sat, python, BaseIntVar, negated)) + DOC(operations_research, sat, python, IntVar, negated)) + .def("__copy__", + [](const std::shared_ptr& self) { + return std::make_shared(self->model_proto(), + self->index()); + }) + .def(py::pickle( + [](std::shared_ptr p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p->model_proto(), p->index()); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2) throw std::runtime_error("Invalid state!"); + + return std::make_shared( + t[0].cast>(), t[1].cast()); + })) // PEP8 Compatibility. + .def("Name", &IntVar::name) + .def("Proto", &IntVar::proto) .def("Not", - [](std::shared_ptr self) { + [](std::shared_ptr self) { if (!self->is_boolean()) { ThrowError(PyExc_TypeError, "negated() is only supported for Boolean variables."); } return self->negated(); }) - .def("Index", &BaseIntVar::index); + .def("Index", &IntVar::index); py::class_, Literal>( m, "NotBooleanVariable", @@ -1120,61 +1252,32 @@ PYBIND11_MODULE(cp_model_helper, m) { .def_property_readonly( "index", [](std::shared_ptr not_var) -> int { - if (!not_var->ok()) { - ThrowError(PyExc_ReferenceError, - "The base variable is not valid."); - } return not_var->index(); }, DOC(operations_research, sat, python, NotBooleanVariable, index)) .def("__str__", [](std::shared_ptr not_var) -> std::string { - if (!not_var->ok()) { - ThrowError(PyExc_ReferenceError, - "The base variable is not valid."); - } return not_var->ToString(); }) .def("__repr__", [](std::shared_ptr not_var) -> std::string { - if (!not_var->ok()) { - ThrowError(PyExc_ReferenceError, - "The base variable is not valid."); - } return not_var->DebugString(); }) .def( "negated", [](std::shared_ptr not_var) - -> std::shared_ptr { - if (!not_var->ok()) { - ThrowError(PyExc_ReferenceError, - "The base variable is not valid."); - } - return not_var->negated(); - }, + -> std::shared_ptr { return not_var->negated(); }, DOC(operations_research, sat, python, NotBooleanVariable, negated)) .def( "__invert__", [](std::shared_ptr not_var) - -> std::shared_ptr { - if (!not_var->ok()) { - ThrowError(PyExc_ReferenceError, - "The base variable is not valid."); - } - return not_var->negated(); - }, + -> std::shared_ptr { return not_var->negated(); }, DOC(operations_research, sat, python, NotBooleanVariable, negated)) + // PEP8 Compatibility. .def( "Not", [](std::shared_ptr not_var) - -> std::shared_ptr { - if (!not_var->ok()) { - ThrowError(PyExc_ReferenceError, - "The base variable is not valid."); - } - return not_var->negated(); - }, + -> std::shared_ptr { return not_var->negated(); }, DOC(operations_research, sat, python, NotBooleanVariable, negated)); py::class_>( diff --git a/ortools/sat/python/cp_model_helper_test.py b/ortools/sat/python/cp_model_helper_test.py index d5901787a7..09a39afc1d 100644 --- a/ortools/sat/python/cp_model_helper_test.py +++ b/ortools/sat/python/cp_model_helper_test.py @@ -18,10 +18,10 @@ import sys from absl.testing import absltest -from google.protobuf import text_format -from ortools.sat import cp_model_pb2 -from ortools.sat import sat_parameters_pb2 +from ortools.sat.python import cp_model_builder from ortools.sat.python import cp_model_helper as cmh +from ortools.sat.python import sat_parameters_builder +from ortools.util.python import sorted_interval_list class Callback(cmh.SolutionCallback): @@ -47,19 +47,6 @@ class BestBoundCallback: self.best_bound = bb -class TestIntVar(cmh.BaseIntVar): - - def __init__(self, index: int, name: str, is_boolean: bool = False) -> None: - cmh.BaseIntVar.__init__(self, index, is_boolean) - self._name = name - - def __str__(self) -> str: - return self._name - - def __repr__(self) -> str: - return self._name - - class CpModelHelperTest(absltest.TestCase): def tearDown(self) -> None: @@ -71,8 +58,8 @@ class CpModelHelperTest(absltest.TestCase): variables { domain: [ -10, 10 ] } variables { domain: [ -5, -5, 3, 6 ] } """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cp_model_builder.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) d0 = cmh.CpSatHelper.variable_domain(model.variables[0]) d1 = cmh.CpSatHelper.variable_domain(model.variables[1]) @@ -112,13 +99,13 @@ class CpModelHelperTest(absltest.TestCase): coeffs: -1 scaling_factor: -1 }""" - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cp_model_builder.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cp_model_builder.OPTIMAL, response_wrapper.status()) self.assertEqual(30.0, response_wrapper.objective_value()) def test_simple_solve_with_core(self): @@ -153,20 +140,21 @@ class CpModelHelperTest(absltest.TestCase): coeffs: -1 scaling_factor: -1 }""" - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cp_model_builder.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) - parameters = sat_parameters_pb2.SatParameters(optimize_with_core=True) + parameters = sat_parameters_builder.SatParameters() + parameters.optimize_with_core = True solve_wrapper = cmh.SolveWrapper() solve_wrapper.set_parameters(parameters) response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cp_model_builder.OPTIMAL, response_wrapper.status()) self.assertEqual(30.0, response_wrapper.objective_value()) def test_simple_solve_with_proto_api(self): - model = cp_model_pb2.CpModelProto() + model = cp_model_builder.CpModelProto() x = model.variables.add() x.domain.extend([-10, 10]) y = model.variables.add() @@ -184,7 +172,7 @@ class CpModelHelperTest(absltest.TestCase): solve_wrapper = cmh.SolveWrapper() response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cp_model_builder.OPTIMAL, response_wrapper.status()) self.assertEqual(30.0, response_wrapper.objective_value()) self.assertEqual(30.0, response_wrapper.best_objective_bound()) self.assertRaises(TypeError, response_wrapper.value, None) @@ -198,19 +186,19 @@ class CpModelHelperTest(absltest.TestCase): constraints { linear { vars: 0 vars: 1 coeffs: 1 coeffs: 1 domain: 6 domain: 6 } } """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cp_model_builder.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() callback = Callback() solve_wrapper.add_solution_callback(callback) - params = sat_parameters_pb2.SatParameters() + params = sat_parameters_builder.SatParameters() params.enumerate_all_solutions = True solve_wrapper.set_parameters(params) response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) self.assertEqual(5, callback.solution_count()) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cp_model_builder.OPTIMAL, response_wrapper.status()) def test_best_bound_callback(self): model_string = """ @@ -225,13 +213,13 @@ class CpModelHelperTest(absltest.TestCase): offset: 0.6 } """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cp_model_builder.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() best_bound_callback = BestBoundCallback() solve_wrapper.add_best_bound_callback(best_bound_callback.new_best_bound) - params = sat_parameters_pb2.SatParameters() + params = sat_parameters_builder.SatParameters() params.num_workers = 1 params.linearization_level = 2 params.log_search_progress = True @@ -239,7 +227,7 @@ class CpModelHelperTest(absltest.TestCase): response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) self.assertEqual(2.6, best_bound_callback.best_bound) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cp_model_builder.OPTIMAL, response_wrapper.status()) def test_model_stats(self): model_string = """ @@ -275,15 +263,16 @@ class CpModelHelperTest(absltest.TestCase): } name: 'testModelStats' """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cp_model_builder.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) stats = cmh.CpSatHelper.model_stats(model) self.assertTrue(stats) def test_int_lin_expr(self): - x = TestIntVar(0, "x") + model = cp_model_builder.CpModelProto() + x = cmh.IntVar(model).with_name("x") self.assertTrue(x.is_integer()) - self.assertIsInstance(x, cmh.BaseIntVar) + self.assertIsInstance(x, cmh.IntVar) self.assertIsInstance(x, cmh.LinearExpr) e1 = x + 2 self.assertTrue(e1.is_integer()) @@ -291,7 +280,7 @@ class CpModelHelperTest(absltest.TestCase): e2 = 3 + x self.assertTrue(e2.is_integer()) self.assertEqual(str(e2), "(x + 3)") - y = TestIntVar(1, "y") + y = cmh.IntVar(model).with_name("y") e3 = y * 5 self.assertTrue(e3.is_integer()) self.assertEqual(str(e3), "(5 * y)") @@ -304,7 +293,8 @@ class CpModelHelperTest(absltest.TestCase): e6 = x - 2 * y self.assertTrue(e6.is_integer()) self.assertEqual(str(e6), "(x + (-2 * y))") - z = TestIntVar(2, "z", True) + z = cmh.IntVar(model).with_name("z") + z.domain = sorted_interval_list.Domain.from_values([0, 1]) e7 = -z self.assertTrue(e7.is_integer()) self.assertEqual(str(e7), "(-z)") @@ -326,9 +316,10 @@ class CpModelHelperTest(absltest.TestCase): self.assertEqual(str(e12), "(x + (-y) + (-2 * z))") def test_float_lin_expr(self): - x = TestIntVar(0, "x") + model = cp_model_builder.CpModelProto() + x = cmh.IntVar(model).with_name("x") self.assertTrue(x.is_integer()) - self.assertIsInstance(x, TestIntVar) + self.assertIsInstance(x, cmh.IntVar) self.assertIsInstance(x, cmh.LinearExpr) e1 = x + 2.5 self.assertFalse(e1.is_integer()) @@ -336,7 +327,7 @@ class CpModelHelperTest(absltest.TestCase): e2 = 3.1 + x self.assertFalse(e2.is_integer()) self.assertEqual(str(e2), "(x + 3.1)") - y = TestIntVar(1, "y") + y = cmh.IntVar(model).with_name("y") e3 = y * 5.2 self.assertFalse(e3.is_integer()) self.assertEqual(str(e3), "(5.2 * y)") @@ -353,7 +344,7 @@ class CpModelHelperTest(absltest.TestCase): self.assertFalse(e7.is_integer()) self.assertEqual(str(e7), "(x + (-(2.4 * y)))") - z = TestIntVar(2, "z") + z = cmh.IntVar(model).with_name("z") e8 = cmh.LinearExpr.sum([x, y, z, -2]) self.assertTrue(e8.is_integer()) self.assertEqual(str(e8), "(x + y + z - 2)") diff --git a/ortools/sat/python/cp_model_numbers.py b/ortools/sat/python/cp_model_numbers.py deleted file mode 100644 index 26b7928df5..0000000000 --- a/ortools/sat/python/cp_model_numbers.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2010-2025 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""helpers methods for the cp_model module.""" - -import numbers -from typing import Any -import numpy as np - - -INT_MIN = -9223372036854775808 # hardcoded to be platform independent. -INT_MAX = 9223372036854775807 - - -def is_boolean(x: Any) -> bool: - """Checks if the x is a boolean.""" - if isinstance(x, bool): - return True - if isinstance(x, np.bool_): - return True - return False - - -def assert_is_zero_or_one(x: Any) -> int: - """Asserts that x is 0 or 1 and returns it as an int.""" - if not isinstance(x, numbers.Integral): - raise TypeError(f"Not a boolean: {x} of type {type(x)}") - x_as_int = int(x) - if x_as_int < 0 or x_as_int > 1: - raise TypeError(f"Not a boolean: {x}") - return x_as_int - - -def to_capped_int64(v: int) -> int: - """Restrict v within [INT_MIN..INT_MAX] range.""" - if v > INT_MAX: - return INT_MAX - if v < INT_MIN: - return INT_MIN - return v - - -def capped_subtraction(x: int, y: int) -> int: - """Saturated arithmetics. Returns x - y truncated to the int64_t range.""" - if y == 0: - return x - if x == y: - if x == INT_MAX or x == INT_MIN: - raise OverflowError("Integer NaN: subtracting INT_MAX or INT_MIN to itself") - return 0 - if x == INT_MAX or x == INT_MIN: - return x - if y == INT_MAX: - return INT_MIN - if y == INT_MIN: - return INT_MAX - return to_capped_int64(x - y) diff --git a/ortools/sat/python/cp_model_numbers_test.py b/ortools/sat/python/cp_model_numbers_test.py deleted file mode 100644 index 1e0b91a0fc..0000000000 --- a/ortools/sat/python/cp_model_numbers_test.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2010-2025 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -from absl.testing import absltest -import numpy as np - -from ortools.sat.python import cp_model_numbers as cmn - - -class CpModelNumbersTest(absltest.TestCase): - - def tearDown(self) -> None: - super().tearDown() - sys.stdout.flush() - - def test_is_boolean(self): - self.assertTrue(cmn.is_boolean(True)) - self.assertTrue(cmn.is_boolean(False)) - self.assertFalse(cmn.is_boolean(1)) - self.assertFalse(cmn.is_boolean(0)) - self.assertTrue(cmn.is_boolean(np.bool_(1))) - self.assertTrue(cmn.is_boolean(np.bool_(0))) - - def test_to_capped_int64(self): - self.assertEqual(cmn.to_capped_int64(cmn.INT_MAX), cmn.INT_MAX) - self.assertEqual(cmn.to_capped_int64(cmn.INT_MAX + 1), cmn.INT_MAX) - self.assertEqual(cmn.to_capped_int64(cmn.INT_MIN), cmn.INT_MIN) - self.assertEqual(cmn.to_capped_int64(cmn.INT_MIN - 1), cmn.INT_MIN) - self.assertEqual(cmn.to_capped_int64(15), 15) - - def test_capped_subtraction(self): - self.assertEqual(cmn.capped_subtraction(10, 5), 5) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MIN, 5), cmn.INT_MIN) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MIN, -5), cmn.INT_MIN) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MAX, 5), cmn.INT_MAX) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MAX, -5), cmn.INT_MAX) - self.assertEqual(cmn.capped_subtraction(2, cmn.INT_MIN), cmn.INT_MAX) - self.assertEqual(cmn.capped_subtraction(2, cmn.INT_MAX), cmn.INT_MIN) - self.assertRaises( - OverflowError, cmn.capped_subtraction, cmn.INT_MAX, cmn.INT_MAX - ) - self.assertRaises( - OverflowError, cmn.capped_subtraction, cmn.INT_MIN, cmn.INT_MIN - ) - self.assertRaises(TypeError, cmn.capped_subtraction, 5, "dummy") - self.assertRaises(TypeError, cmn.capped_subtraction, "dummy", 5) - - -if __name__ == "__main__": - absltest.main() diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index 7e1b1b1554..aa647a6bb2 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -21,8 +21,8 @@ from absl.testing import absltest import numpy as np import pandas as pd -from ortools.sat import cp_model_pb2 from ortools.sat.python import cp_model +from ortools.sat.python import cp_model_builder from ortools.sat.python import cp_model_helper as cmh @@ -184,6 +184,14 @@ class CpModelTest(absltest.TestCase): super().tearDown() sys.stdout.flush() + def test_is_boolean(self): + self.assertTrue(cp_model.arg_is_boolean(True)) + self.assertTrue(cp_model.arg_is_boolean(False)) + self.assertFalse(cp_model.arg_is_boolean(1)) + self.assertFalse(cp_model.arg_is_boolean(0)) + self.assertTrue(cp_model.arg_is_boolean(np.bool_(1))) + self.assertTrue(cp_model.arg_is_boolean(np.bool_(0))) + def test_create_integer_variable(self) -> None: model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") @@ -230,6 +238,9 @@ class CpModelTest(absltest.TestCase): one = model.new_constant(1) self.assertEqual("1", str(one)) self.assertEqual("not(1)", str(~one)) + no_name = model.new_bool_var("") + self.assertEqual("b4", str(no_name)) + self.assertEqual("not(b4)", str(~no_name)) z = model.new_int_var(0, 2, "z") self.assertRaises(TypeError, z.negated) self.assertRaises(TypeError, z.__invert__) @@ -284,14 +295,14 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, solver.float_value, None) self.assertRaises(TypeError, solver.boolean_value, None) - def test_linear_constraint(self) -> None: + def test_empty_linear_constraint(self) -> None: model = cp_model.CpModel() model.add_linear_constraint(5, 0, 10) model.add_linear_constraint(-1, 0, 10) self.assertLen(model.proto.constraints, 2) - self.assertTrue(model.proto.constraints[0].HasField("bool_and")) + self.assertTrue(model.proto.constraints[0].has_bool_and()) self.assertEmpty(model.proto.constraints[0].bool_and.literals) - self.assertTrue(model.proto.constraints[1].HasField("bool_or")) + self.assertTrue(model.proto.constraints[1].has_bool_or()) self.assertEmpty(model.proto.constraints[1].bool_or.literals) def test_linear_non_equal(self) -> None: @@ -315,6 +326,17 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2, ct.linear.domain[0]) self.assertEqual(2, ct.linear.domain[1]) + def test_large_constants(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(-10, 10, "x") + ct = model.add(x * 50000000000 == 1234567890).proto + self.assertLen(ct.linear.vars, 1) + self.assertLen(ct.linear.coeffs, 1) + self.assertEqual(50000000000, ct.linear.coeffs[0]) + self.assertLen(ct.linear.domain, 2) + self.assertEqual(1234567890, ct.linear.domain[0]) + self.assertEqual(1234567890, ct.linear.domain[1]) + def testGe(self) -> None: model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") @@ -476,7 +498,7 @@ class CpModelTest(absltest.TestCase): model.add(x * 2 - 1 * y == 1) model.minimize(x * 1 - 2 * y + 3) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(5, solver.value(x)) self.assertEqual(15, solver.value(x * 3)) self.assertEqual(6, solver.value(1 + x)) @@ -488,7 +510,7 @@ class CpModelTest(absltest.TestCase): y = model.new_int_var(0, 10, "y") model.maximize(x.negated() * 3.5 + x.negated() - y + 2 * y + 1.6) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertFalse(solver.boolean_value(x)) self.assertTrue(solver.boolean_value(x.negated())) self.assertEqual(-10, solver.value(-y)) @@ -505,7 +527,7 @@ class CpModelTest(absltest.TestCase): + cp_model.LinearExpr.weighted_sum([x3, x4.negated()], [2, 4]) ) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(5, solver.value(3 + 2 * x1)) self.assertEqual(3, solver.value(x1 + x2 + x3)) self.assertEqual(1, solver.value(cp_model.LinearExpr.sum([x1, x2, x3, 0, -2]))) @@ -525,7 +547,7 @@ class CpModelTest(absltest.TestCase): model.add(2 * x - y == 1) model.maximize(x - 2 * y + 3) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(-4, solver.value(x)) self.assertEqual(-9, solver.value(y)) self.assertEqual(17, solver.objective_value) @@ -536,7 +558,7 @@ class CpModelTest(absltest.TestCase): model.add(x >= -1) model.minimize(10) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(10, solver.objective_value) def test_maximize_constant(self) -> None: @@ -545,7 +567,7 @@ class CpModelTest(absltest.TestCase): model.add(x >= -1) model.maximize(5) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(5, solver.objective_value) def test_add_true(self) -> None: @@ -554,7 +576,7 @@ class CpModelTest(absltest.TestCase): model.add(3 >= -1) model.minimize(x) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(-10, solver.value(x)) def test_add_false(self) -> None: @@ -563,7 +585,8 @@ class CpModelTest(absltest.TestCase): model.add(3 <= -1) model.minimize(x) solver = cp_model.CpSolver() - self.assertEqual("INFEASIBLE", solver.status_name(solver.solve(model))) + status: cp_model_builder.CpSolverStatus = solver.solve(model) + self.assertEqual("INFEASIBLE", status.name) def test_sum(self) -> None: model = cp_model.CpModel() @@ -838,7 +861,7 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[0].linear.vars, 1) self.assertEqual(x[3].index, model.proto.constraints[0].linear.vars[0]) self.assertEqual(1, model.proto.constraints[0].linear.coeffs[0]) - self.assertEqual([2, 2], model.proto.constraints[0].linear.domain) + self.assertEqual([2, 2], list(model.proto.constraints[0].linear.domain)) def test_affine_element(self) -> None: model = cp_model.CpModel() @@ -1285,12 +1308,12 @@ class CpModelTest(absltest.TestCase): self.assertEqual(~i.size_expr(), ~y) self.assertRaises(TypeError, i.start_expr().negated) - proto = cp_model_pb2.LinearExpressionProto() + proto = cp_model_builder.LinearExpressionProto() proto.vars.append(x.index) proto.coeffs.append(1) proto.vars.append(y.index) proto.coeffs.append(2) - expr1 = model.rebuild_from_linear_expression_proto(proto) + expr1 = cp_model.rebuild_from_linear_expression_proto(proto, model.proto) canonical_expr1 = cmh.FlatIntExpr(expr1) self.assertEqual(canonical_expr1.vars[0], x) self.assertEqual(canonical_expr1.vars[1], y) @@ -1301,7 +1324,7 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, canonical_expr1.vars[0].negated) proto.offset = 2 - expr2 = model.rebuild_from_linear_expression_proto(proto) + expr2 = cp_model.rebuild_from_linear_expression_proto(proto, model.proto) canonical_expr2 = cmh.FlatIntExpr(expr2) self.assertEqual(canonical_expr2.vars[0], x) self.assertEqual(canonical_expr2.vars[1], y) @@ -1474,7 +1497,7 @@ class CpModelTest(absltest.TestCase): self.assertEqual(repr(i), "i(start = x, size = 2, end = y)") b = model.new_bool_var("b") self.assertEqual(repr(b), "b(0..1)") - self.assertEqual(repr(~b), "NotBooleanVariable(index=3)") + self.assertEqual(repr(~b), "NotBooleanVariable(var_index=3)") x1 = model.new_int_var(0, 4, "x1") y1 = model.new_int_var(0, 3, "y1") j = model.new_optional_interval_var(x1, 2, y1, b, "j") @@ -1486,16 +1509,6 @@ class CpModelTest(absltest.TestCase): repr(k), "k(start = x2, size = 2, end = y2, is_present = not(b))" ) - def testDisplayBounds(self) -> None: - self.assertEqual("10..20", cp_model.display_bounds([10, 20])) - self.assertEqual("10", cp_model.display_bounds([10, 10])) - self.assertEqual("10..15, 20..30", cp_model.display_bounds([10, 15, 20, 30])) - - def test_short_name(self) -> None: - model = cp_model.CpModel() - model.proto.variables.add(domain=[5, 10]) - self.assertEqual("[5..10]", cp_model.short_name(model.proto, 0)) - def test_integer_expression_errors(self) -> None: model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") @@ -1525,10 +1538,23 @@ class CpModelTest(absltest.TestCase): model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") y = model.new_int_var(-10, 10, "y") + b = model.new_bool_var("b") model.add_linear_constraint(x + 2 * y, 0, 10) model.minimize(y) solver = cp_model.CpSolver() self.assertRaises(RuntimeError, solver.value, x) + self.assertRaises(RuntimeError, solver.boolean_value, b) + self.assertRaises(RuntimeError, lambda: solver.best_objective_bound) + self.assertRaises(RuntimeError, lambda: solver.deterministic_time) + self.assertRaises(RuntimeError, lambda: solver.num_boolean_propagations) + self.assertRaises(RuntimeError, lambda: solver.num_booleans) + self.assertRaises(RuntimeError, lambda: solver.num_branches) + self.assertRaises(RuntimeError, lambda: solver.num_conflicts) + self.assertRaises(RuntimeError, lambda: solver.num_integer_propagations) + self.assertRaises(RuntimeError, lambda: solver.objective_value) + self.assertRaises(RuntimeError, lambda: solver.response_proto) + self.assertRaises(RuntimeError, lambda: solver.user_time) + self.assertRaises(RuntimeError, lambda: solver.wall_time) solver.solve(model) self.assertRaises(TypeError, solver.value, "not_a_variable") self.assertRaises(TypeError, model.add_bool_or, [x, y]) @@ -1885,7 +1911,7 @@ class CpModelTest(absltest.TestCase): with self.assertRaises(ValueError): new_model.get_interval_var_from_proto_index(-1) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): new_model.get_bool_var_from_proto_index(x.index) with self.assertRaises(ValueError): @@ -1908,8 +1934,8 @@ class CpModelTest(absltest.TestCase): deepcopy_c = copy.deepcopy(c) self.assertIsNot(deepcopy_c.model, c.model) self.assertIsNot(deepcopy_c.var, c.var) - self.assertIs(deepcopy_c.model.proto, deepcopy_c.var.model_proto) - self.assertIs( + self.assertEqual(deepcopy_c.model.proto, deepcopy_c.var.model_proto) + self.assertEqual( deepcopy_c.var, deepcopy_c.model.get_int_var_from_proto_index(x.index), ) diff --git a/ortools/sat/python/gen_cp_model_builder_pybind.cc b/ortools/sat/python/gen_cp_model_builder_pybind.cc new file mode 100644 index 0000000000..bc19b8f5d0 --- /dev/null +++ b/ortools/sat/python/gen_cp_model_builder_pybind.cc @@ -0,0 +1,63 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/flags/parse.h" +#include "absl/flags/usage.h" +#include "absl/log/die_if_null.h" +#include "absl/log/initialize.h" +#include "absl/strings/str_format.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/python/wrappers.h" + +namespace operations_research::sat::python { + +void ParseAndGenerate() { + absl::PrintF( + R"( + +// This is a generated file, do not edit. +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "google/protobuf/text_format.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "ortools/port/proto_utils.h" +#include "ortools/sat/cp_model.pb.h" + +namespace py = ::pybind11; + +namespace operations_research::sat::python { + +PYBIND11_MODULE(cp_model_builder, py_module) { +%s +} // PYBIND11_MODULE + +} // namespace operations_research::sat::python +)", + GeneratePybindCode({ABSL_DIE_IF_NULL(CpModelProto::descriptor()), + ABSL_DIE_IF_NULL(CpSolverResponse::descriptor())})); +} + +} // namespace operations_research::sat::python + +int main(int argc, char* argv[]) { + // We do not use InitGoogle() to avoid linking with or-tools as this would + // create a circular dependency. + absl::InitializeLog(); + absl::SetProgramUsageMessage(argv[0]); + absl::ParseCommandLine(argc, argv); + operations_research::sat::python::ParseAndGenerate(); + return 0; +} diff --git a/ortools/sat/python/gen_sat_parameters_builder_pybind.cc b/ortools/sat/python/gen_sat_parameters_builder_pybind.cc new file mode 100644 index 0000000000..cf2596abce --- /dev/null +++ b/ortools/sat/python/gen_sat_parameters_builder_pybind.cc @@ -0,0 +1,59 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/flags/parse.h" +#include "absl/flags/usage.h" +#include "absl/log/die_if_null.h" +#include "absl/log/initialize.h" +#include "absl/strings/str_format.h" +#include "ortools/sat/python/wrappers.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace operations_research::sat::python { + +void ParseAndGenerate() { + absl::PrintF( + R"( + +// This is a generated file, do not edit. +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "google/protobuf/text_format.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "ortools/port/proto_utils.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace py = ::pybind11; +namespace operations_research::sat::python { +PYBIND11_MODULE(sat_parameters_builder, py_module) { +%s +} // PYBIND11_MODULE +} // namespace operations_research::sat::python +)", + GeneratePybindCode({ABSL_DIE_IF_NULL(SatParameters::descriptor())})); +} + +} // namespace operations_research::sat::python + +int main(int argc, char* argv[]) { + // We do not use InitGoogle() to avoid linking with or-tools as this would + // create a circular dependency. + absl::InitializeLog(); + absl::SetProgramUsageMessage(argv[0]); + absl::ParseCommandLine(argc, argv); + operations_research::sat::python::ParseAndGenerate(); + return 0; +} diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index f8c2954f62..0b87598b67 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -26,6 +26,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" #include "ortools/util/fp_roundtrip_conv.h" #include "ortools/util/sorted_interval_list.h" @@ -143,8 +145,7 @@ void FloatExprVisitor::AddToProcess(std::shared_ptr expr, void FloatExprVisitor::AddConstant(double constant) { offset_ += constant; } -void FloatExprVisitor::AddVarCoeff(std::shared_ptr var, - double coeff) { +void FloatExprVisitor::AddVarCoeff(std::shared_ptr var, double coeff) { canonical_terms_[var] += coeff; } @@ -156,7 +157,7 @@ void FloatExprVisitor::ProcessAll() { } } -double FloatExprVisitor::Process(std::vector>* vars, +double FloatExprVisitor::Process(std::vector>* vars, std::vector* coeffs) { ProcessAll(); @@ -316,7 +317,7 @@ std::string FlatIntExpr::DebugString() const { return absl::StrCat( "FlatIntExpr([", absl::StrJoin(vars_, ", ", - [](std::string* out, std::shared_ptr var) { + [](std::string* out, std::shared_ptr var) { absl::StrAppend(out, var->DebugString()); }), "], [", absl::StrJoin(coeffs_, ", "), "], ", offset_, ")"); @@ -745,8 +746,7 @@ void IntExprVisitor::AddToProcess(std::shared_ptr expr, void IntExprVisitor::AddConstant(int64_t constant) { offset_ += constant; } -void IntExprVisitor::AddVarCoeff(std::shared_ptr var, - int64_t coeff) { +void IntExprVisitor::AddVarCoeff(std::shared_ptr var, int64_t coeff) { canonical_terms_[var] += coeff; } @@ -759,7 +759,7 @@ bool IntExprVisitor::ProcessAll() { return true; } -bool IntExprVisitor::Process(std::vector>* vars, +bool IntExprVisitor::Process(std::vector>* vars, std::vector* coeffs, int64_t* offset) { if (!ProcessAll()) return false; vars->clear(); @@ -789,64 +789,146 @@ bool IntExprVisitor::Evaluate(const CpSolverResponse& solution, // the same index and different models. int64_t Literal::Hash() const { return absl::HashOf(index()); } -bool BaseIntVarComparator::operator()(std::shared_ptr lhs, - std::shared_ptr rhs) const { +bool IntVarComparator::operator()(std::shared_ptr lhs, + std::shared_ptr rhs) const { return lhs->index() < rhs->index(); } -BaseIntVar::BaseIntVar(int index, bool is_boolean) - : index_(index), is_boolean_(is_boolean) {} - -std::shared_ptr BaseIntVar::negated() { - if (negated_ == nullptr) { - std::shared_ptr self = - std::static_pointer_cast(shared_from_this()); - negated_ = std::make_shared(self); +std::string IntVar::name() const { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return ""; } - return negated_; + return model_proto_->variables(index_).name(); } -int NotBooleanVariable::index() const { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return -var->index() - 1; +void IntVar::SetName(const std::string& name) { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return; + } + if (name.empty()) { + model_proto_->mutable_variables(index_)->clear_name(); + } else { + model_proto_->mutable_variables(index_)->set_name(name); + } } +Domain IntVar::domain() const { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return Domain(); + } + return ReadDomainFromProto(model_proto_->variables(index_)); +} + +void IntVar::SetDomain(const Domain& domain) { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return; + } + FillDomainInProto(domain, model_proto_->mutable_variables(index_)); +} + +std::shared_ptr IntVar::model_proto() const { + return model_proto_; +} + +IntegerVariableProto* IntVar::proto() const { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return nullptr; + } + return model_proto_->mutable_variables(index_); +} + +bool IntVar::is_boolean() const { + IntegerVariableProto* var_proto = proto(); + if (var_proto == nullptr) return false; + return var_proto->domain_size() == 2 && var_proto->domain(0) >= 0 && + var_proto->domain(1) <= 1; +} + +bool IntVar::is_fixed() const { + IntegerVariableProto* var_proto = proto(); + if (var_proto == nullptr) return false; + return var_proto->domain_size() == 2 && + var_proto->domain(0) == var_proto->domain(1); +} + +std::shared_ptr IntVar::negated() const { + return std::make_shared(model_proto_, index_); +} + +namespace { +std::string VarDomainToString(IntegerVariableProto* var_proto) { + std::string domain_str; + for (int i = 0; i < var_proto->domain_size(); i += 2) { + const int64_t lb = var_proto->domain(i); + const int64_t ub = var_proto->domain(i + 1); + if (i > 0) absl::StrAppend(&domain_str, ", "); + if (lb == ub) { + absl::StrAppend(&domain_str, lb); + } else { + absl::StrAppend(&domain_str, lb, "..", ub); + } + } + return domain_str; +} + +} // namespace + +std::string IntVar::ToString() const { + std::string var_name = name(); + IntegerVariableProto* var_proto = proto(); + if (var_name.empty()) { + if (is_fixed() && var_proto != nullptr && var_proto->domain_size() >= 2) { + return absl::StrCat(var_proto->domain(0)); + } else if (is_boolean()) { + return absl::StrCat("b", index_); + } else { + return absl::StrCat("x", index_); + } + } + return var_name; +} + +std::string IntVar::DebugString() const { + std::string var_name = name(); + if (var_name.empty()) { + if (is_boolean()) { + var_name = absl::StrCat("b", index_); + } else { + var_name = absl::StrCat("x", index_); + } + } + IntegerVariableProto* var_proto = proto(); + if (var_proto == nullptr) return var_name; + return absl::StrCat(var_name, "(", VarDomainToString(var_proto), ")"); +} + +int NotBooleanVariable::index() const { return NegatedRef(var_index_); } + /** * Returns the negation of the current literal, that is the original Boolean * variable. */ -std::shared_ptr NotBooleanVariable::negated() { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return var; +std::shared_ptr NotBooleanVariable::negated() const { + return std::make_shared(model_proto_, var_index_); } bool NotBooleanVariable::VisitAsInt(IntExprVisitor& lin, int64_t c) { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - lin.AddVarCoeff(var, -c); + lin.AddVarCoeff(std::make_shared(model_proto_, var_index_), -c); lin.AddConstant(c); return true; } void NotBooleanVariable::VisitAsFloat(FloatExprVisitor& lin, double c) { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - lin.AddVarCoeff(var, -c); + lin.AddVarCoeff(std::make_shared(model_proto_, var_index_), -c); lin.AddConstant(c); } std::string NotBooleanVariable::ToString() const { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return absl::StrCat("not(", var->ToString(), ")"); + return absl::StrCat("not(", negated()->ToString(), ")"); } std::string NotBooleanVariable::DebugString() const { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return absl::StrCat("NotBooleanVariable(index=", var->index(), ")"); + return absl::StrCat("NotBooleanVariable(var_index=", var_index_, ")"); } BoundedLinearExpression::BoundedLinearExpression( @@ -868,7 +950,7 @@ BoundedLinearExpression::BoundedLinearExpression( } const Domain& BoundedLinearExpression::bounds() const { return bounds_; } -const std::vector>& BoundedLinearExpression::vars() +const std::vector>& BoundedLinearExpression::vars() const { return vars_; } @@ -956,7 +1038,7 @@ std::string BoundedLinearExpression::DebugString() const { return absl::StrCat( "BoundedLinearExpression(vars=[", absl::StrJoin(vars_, ", ", - [](std::string* out, std::shared_ptr var) { + [](std::string* out, std::shared_ptr var) { absl::StrAppend(out, var->DebugString()); }), "], coeffs=[", absl::StrJoin(coeffs_, ", "), "], offset=", offset_, diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h index 06d973f9ea..3e74f256a5 100644 --- a/ortools/sat/python/linear_expr.h +++ b/ortools/sat/python/linear_expr.h @@ -36,7 +36,7 @@ class FloatExprVisitor; class LinearExpr; class IntExprVisitor; class LinearExpr; -class BaseIntVar; +class IntVar; class NotBooleanVariable; /** @@ -152,9 +152,9 @@ class LinearExpr : public std::enable_shared_from_this { }; /// Compare the indices of variables. -struct BaseIntVarComparator { - bool operator()(std::shared_ptr lhs, - std::shared_ptr rhs) const; +struct IntVarComparator { + bool operator()(std::shared_ptr lhs, + std::shared_ptr rhs) const; }; /// A visitor class to process a floating point linear expression. @@ -162,15 +162,15 @@ class FloatExprVisitor { public: void AddToProcess(std::shared_ptr expr, double coeff); void AddConstant(double constant); - void AddVarCoeff(std::shared_ptr var, double coeff); + void AddVarCoeff(std::shared_ptr var, double coeff); void ProcessAll(); - double Process(std::vector>* vars, + double Process(std::vector>* vars, std::vector* coeffs); double Evaluate(const CpSolverResponse& solution); private: std::vector, double>> to_process_; - absl::btree_map, double, BaseIntVarComparator> + absl::btree_map, double, IntVarComparator> canonical_terms_; double offset_ = 0; }; @@ -188,7 +188,7 @@ class FlatFloatExpr : public LinearExpr { /// expression. explicit FlatFloatExpr(std::shared_ptr expr); /// Returns the array of variables of the flattened expression. - const std::vector>& vars() const { return vars_; } + const std::vector>& vars() const { return vars_; } /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const { return coeffs_; } /// Returns the offset of the flattened expression. @@ -202,7 +202,7 @@ class FlatFloatExpr : public LinearExpr { } private: - std::vector> vars_; + std::vector> vars_; std::vector coeffs_; double offset_ = 0; }; @@ -212,15 +212,15 @@ class IntExprVisitor { public: void AddToProcess(std::shared_ptr expr, int64_t coeff); void AddConstant(int64_t constant); - void AddVarCoeff(std::shared_ptr var, int64_t coeff); + void AddVarCoeff(std::shared_ptr var, int64_t coeff); bool ProcessAll(); - bool Process(std::vector>* vars, + bool Process(std::vector>* vars, std::vector* coeffs, int64_t* offset); bool Evaluate(const CpSolverResponse& solution, int64_t* value); private: std::vector, int64_t>> to_process_; - absl::btree_map, int64_t, BaseIntVarComparator> + absl::btree_map, int64_t, IntVarComparator> canonical_terms_; int64_t offset_ = 0; }; @@ -238,7 +238,7 @@ class FlatIntExpr : public LinearExpr { /// expression. explicit FlatIntExpr(std::shared_ptr expr); /// Returns the array of variables of the flattened expression. - const std::vector>& vars() const { return vars_; } + const std::vector>& vars() const { return vars_; } /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const { return coeffs_; } /// Returns the offset of the flattened expression. @@ -265,7 +265,7 @@ class FlatIntExpr : public LinearExpr { std::string DebugString() const override; private: - std::vector> vars_; + std::vector> vars_; std::vector coeffs_; int64_t offset_ = 0; bool ok_ = true; @@ -479,90 +479,115 @@ class Literal : public LinearExpr { * Returns: * The negation of the current literal. */ - virtual std::shared_ptr negated() = 0; + virtual std::shared_ptr negated() const = 0; /// Returns the hash of the current literal. int64_t Hash() const; }; /** - * A class to hold a variable index. It is the base class for Integer - * variables. + * An integer variable. + * + * An IntVar is an object that can take on any integer value within defined + * ranges. Variables appear in constraint like: + * + * x + y >= 5 + * AllDifferent([x, y, z]) + * + * Solving a model is equivalent to finding, for each variable, a single value + * from the set of initial values (called the initial domain), such that the + * model is feasible, or optimal if you provided an objective function. */ -class BaseIntVar : public Literal { +class IntVar : public Literal { public: - explicit BaseIntVar(int index) : index_(index), is_boolean_(false) { + IntVar(std::shared_ptr model, int index) + : model_proto_(model), index_(index) { DCHECK_GE(index, 0); } - BaseIntVar(int index, bool is_boolean); - ~BaseIntVar() override = default; + explicit IntVar(std::shared_ptr model) + : model_proto_(model), index_(model->variables_size()) { + model->add_variables(); + } + ~IntVar() override = default; + + /// Returns the index of the variable in the model. int index() const override { return index_; } + /// Returns the name of the variable. + std::string name() const; + + /// Overwrite the name of the variable. If name is empty, this method clears + /// the name of the variable. + void SetName(const std::string& name); + + /// Returns a copy of the domain of the variable. + Domain domain() const; + + /// Overwrite the domain of the variable. + void SetDomain(const Domain& domain); + + /// Returns the model proto. + std::shared_ptr model_proto() const; + + /// Returns the proto of the variable. + IntegerVariableProto* proto() const; + + /// Returns the negation of the current variable. + std::shared_ptr negated() const override; + + /// Returns true if the variable has a Boolean domain (0 or 1). + bool is_boolean() const; + + /// Returns true if the variable is fixed. + bool is_fixed() const; + bool VisitAsInt(IntExprVisitor& lin, int64_t c) override { - std::shared_ptr var = - std::static_pointer_cast(shared_from_this()); + std::shared_ptr var = + std::static_pointer_cast(shared_from_this()); lin.AddVarCoeff(var, c); return true; } void VisitAsFloat(FloatExprVisitor& lin, double c) override { - std::shared_ptr var = - std::static_pointer_cast(shared_from_this()); + std::shared_ptr var = + std::static_pointer_cast(shared_from_this()); lin.AddVarCoeff(var, c); } - std::string ToString() const override { - if (negated_ != nullptr) { - return absl::StrCat("BooleanBaseIntVar(", index_, ")"); - } else { - return absl::StrCat("BaseIntVar(", index_, ")"); - } - } + std::string ToString() const override; - std::string DebugString() const override { - return absl::StrCat("BaseIntVar(index=", index_, - ", is_boolean=", negated_ != nullptr, ")"); - } + std::string DebugString() const override; - /// Returns the negation of the current variable. - std::shared_ptr negated() override; + bool operator<(const IntVar& other) const { return index_ < other.index_; } - /// Returns true if the variable has a Boolean domain (0 or 1). - bool is_boolean() const { return is_boolean_; } - - bool operator<(const BaseIntVar& other) const { - return index_ < other.index_; - } - - protected: + private: + std::shared_ptr model_proto_; const int index_; - const bool is_boolean_; - std::shared_ptr negated_; }; template -H AbslHashValue(H h, std::shared_ptr i) { +H AbslHashValue(H h, std::shared_ptr i) { return H::combine(std::move(h), i->index()); } /// A class to hold a negated variable index. class NotBooleanVariable : public Literal { public: - explicit NotBooleanVariable(std::shared_ptr var) : var_(var) {} + explicit NotBooleanVariable(std::shared_ptr model_proto, + int var_index) + : model_proto_(model_proto), var_index_(var_index) {} ~NotBooleanVariable() override = default; /// Returns the index of the current literal. int index() const override; - bool ok() const { return !var_.expired(); } - /** * Returns the negation of the current literal, that is the original Boolean * variable. */ - std::shared_ptr negated() override; + std::shared_ptr negated() const override; bool VisitAsInt(IntExprVisitor& lin, int64_t c) override; @@ -573,11 +598,8 @@ class NotBooleanVariable : public Literal { std::string DebugString() const override; private: - // We keep a weak ptr to the base variable to avoid a circular dependency. - // The base variable holds a shared pointer to the negated variable. - // Any call to a risky method is checked at the pybind11 level to raise a - // python exception before the call is made. - std::weak_ptr var_; + std::shared_ptr model_proto_; + const int var_index_; }; /// A class to hold a linear expression with bounds. @@ -597,7 +619,7 @@ class BoundedLinearExpression { /// Returns the bounds constraining the expression passed to the constructor. const Domain& bounds() const; /// Returns the array of variables of the flattened expression. - const std::vector>& vars() const; + const std::vector>& vars() const; /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const; /// Returns the offset of the flattened expression. @@ -609,7 +631,7 @@ class BoundedLinearExpression { bool CastToBool(bool* result) const; private: - std::vector> vars_; + std::vector> vars_; std::vector coeffs_; int64_t offset_; const Domain bounds_; diff --git a/ortools/sat/python/linear_expr_doc.h b/ortools/sat/python/linear_expr_doc.h index d36484d457..b62752c217 100644 --- a/ortools/sat/python/linear_expr_doc.h +++ b/ortools/sat/python/linear_expr_doc.h @@ -46,55 +46,53 @@ static const char* __doc_operations_research_sat_python_AbslHashValue = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar = - R"doc(A class to hold a variable index. It is the base class for Integer -variables.)doc"; +static const char* __doc_operations_research_sat_python_IntVar = + R"doc(A class to hold an integer or Boolean variable)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_2 = - R"doc(A class to hold a variable index. It is the base class for Integer -variables.)doc"; +static const char* __doc_operations_research_sat_python_IntVar_2 = + R"doc(A class to hold an integer or Boolean variable)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVarComparator = +static const char* __doc_operations_research_sat_python_IntVarComparator = R"doc(Compare the indices of variables.)doc"; static const char* - __doc_operations_research_sat_python_BaseIntVarComparator_operator_call = + __doc_operations_research_sat_python_IntVarComparator_operator_call = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_BaseIntVar = +static const char* __doc_operations_research_sat_python_IntVar_IntVar = R"doc()doc"; -static const char* - __doc_operations_research_sat_python_BaseIntVar_BaseIntVar_2 = R"doc()doc"; - -static const char* __doc_operations_research_sat_python_BaseIntVar_DebugString = +static const char* __doc_operations_research_sat_python_IntVar_IntVar_2 = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_ToString = +static const char* __doc_operations_research_sat_python_IntVar_DebugString = R"doc()doc"; -static const char* - __doc_operations_research_sat_python_BaseIntVar_VisitAsFloat = R"doc()doc"; - -static const char* __doc_operations_research_sat_python_BaseIntVar_VisitAsInt = +static const char* __doc_operations_research_sat_python_IntVar_ToString = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_index = +static const char* __doc_operations_research_sat_python_IntVar_VisitAsFloat = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_index_2 = +static const char* __doc_operations_research_sat_python_IntVar_VisitAsInt = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_is_boolean = +static const char* __doc_operations_research_sat_python_IntVar_index = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntVar_index_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntVar_is_boolean = R"doc(Returns true if the variable has a Boolean domain (0 or 1).)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_negated = +static const char* __doc_operations_research_sat_python_IntVar_negated = R"doc(Returns the negation of the current variable.)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_negated_2 = +static const char* __doc_operations_research_sat_python_IntVar_negated_2 = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_operator_lt = +static const char* __doc_operations_research_sat_python_IntVar_operator_lt = R"doc()doc"; static const char* diff --git a/ortools/sat/python/sat_parameters_builder_test.py b/ortools/sat/python/sat_parameters_builder_test.py new file mode 100644 index 0000000000..2be6d40b58 --- /dev/null +++ b/ortools/sat/python/sat_parameters_builder_test.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright 2010-2025 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test sat parameters builder.""" + +from absl.testing import absltest +from ortools.sat.python import sat_parameters_builder + + +class SatParametersBuilderTest(absltest.TestCase): + + def test_basic_api(self) -> None: + params = sat_parameters_builder.SatParameters() + + # Test that we can set and get an integer parameter. + params.num_workers = 10 + self.assertEqual(params.num_workers, 10) + + # Test that we can set and get an enum parameter. + self.assertEqual( + params.clause_cleanup_ordering, + sat_parameters_builder.SatParameters.ClauseOrdering.CLAUSE_ACTIVITY, + ) + params.clause_cleanup_ordering = ( + sat_parameters_builder.SatParameters.ClauseOrdering.CLAUSE_LBD + ) + self.assertEqual( + params.clause_cleanup_ordering, + sat_parameters_builder.SatParameters.ClauseOrdering.CLAUSE_LBD, + ) + + # Test that we can set and get a repeated string parameter. + params.subsolvers.append("no_lp") + self.assertLen(params.subsolvers, 1) + self.assertEqual(params.subsolvers[0], "no_lp") + + +if __name__ == "__main__": + absltest.main() diff --git a/ortools/sat/python/wrappers.cc b/ortools/sat/python/wrappers.cc new file mode 100644 index 0000000000..b7bef4e925 --- /dev/null +++ b/ortools/sat/python/wrappers.cc @@ -0,0 +1,450 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/python/wrappers.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "google/protobuf/descriptor.h" + +namespace operations_research::sat::python { + +// A class that generates pybind11 code for a proto message. +class Generator { + public: + struct Context { + static Context TopLevel(const google::protobuf::Descriptor& msg) { + const std::string cpp_name = GetQualifiedCppName(msg); + const std::string shared_name = + absl::StrCat("std::shared_ptr<", cpp_name, ">"); + return {.cpp_name = cpp_name, .self_mutable_name = shared_name}; + } + + static Context Nested(const google::protobuf::Descriptor& msg) { + const std::string cpp_name = GetQualifiedCppName(msg); + return {.cpp_name = cpp_name, + .self_mutable_name = absl::StrCat(cpp_name, "*")}; + } + + std::string cpp_name; + std::string self_mutable_name; + }; + + explicit Generator( + absl::Span roots) + : message_stack_(roots.begin(), roots.end()) { + // DFS on root. + while (!message_stack_.empty()) { + const google::protobuf::Descriptor* const msg = message_stack_.back(); + message_stack_.pop_back(); + if (!visited_messages_.insert(msg).second) continue; + const bool is_top_level = absl::c_linear_search(roots, msg); + current_context_ = + is_top_level ? Context::TopLevel(*msg) : Context::Nested(*msg); + if (is_top_level) { + GenerateTopLevelMessageDecl(*msg); + } else { + GenerateMessageDecl(*msg); + } + GenerateMessageFields(*msg); + absl::StrAppend(&out_, ";\n"); + } + + // Now generate wrappers for enums, repeated and repeated ptr fields that + // were encountered along the way. + for (const google::protobuf::EnumDescriptor* pb_enum : enum_types_) { + GenerateEnumDecl(*pb_enum); + } + for (const google::protobuf::Descriptor* msg : repeated_ptr_types_) { + GenerateRepeatedPtrDecl(*msg); + } + for (const absl::string_view scalar_type : repeated_scalar_types_) { + GenerateRepeatedScalarDecl(scalar_type); + } + } + + std::string Result() && { return std::move(out_); } + + private: + template + static std::string GetQualifiedCppName(const DescriptorT& descriptor) { + return absl::StrReplaceAll(descriptor.full_name(), {{".", "::"}}); + } + + template + static std::string GetEscapedName(const DescriptorT& descriptor) { + return absl::StrReplaceAll(descriptor.full_name(), {{".", "_"}}); + } + + static std::string GetCppType( + const google::protobuf::FieldDescriptor::CppType cpp_type, + const google::protobuf::FieldDescriptor& field) { + switch (cpp_type) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return "int32_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return "int64_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return "uint32_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return "uint64_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return "double"; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return "float"; + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return "bool"; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + return GetQualifiedCppName(*field.enum_type()); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return "std::string"; + default: + LOG(FATAL) << "Unsupported type: " << cpp_type; + } + } + + // Generates a pybind11 wrapper class declaration for a top level message. + void GenerateTopLevelMessageDecl(const google::protobuf::Descriptor& msg) { + CHECK(wrapper_id_.emplace(&msg, wrapper_id_.size()).second) + << "duplicate message: " << msg.full_name(); + absl::SubstituteAndAppend(&out_, R"( + const auto $0 = py::class_<$1, std::shared_ptr<$1>>($2, "$3"))", + GetWrapperName(&msg), current_context_.cpp_name, + GetWrapperName(msg.containing_type()), + msg.name()); + // Add constructor and utilities. + absl::SubstituteAndAppend(&out_, R"( + .def(py::init<>()) + .def("copy_from", + [](std::shared_ptr<$0> self, std::shared_ptr<$0> other) { + self->CopyFrom(*other); + }) + .def("merge_from", + [](std::shared_ptr<$0> self, std::shared_ptr<$0> other) { + self->MergeFrom(*other); + }) + .def("merge_text_format", + [](std::shared_ptr<$0> self, const std::string& text) { + return google::protobuf::TextFormat::MergeFromString(text, self.get()); + }) + .def("parse_text_format", + [](std::shared_ptr<$0> self, const std::string& text) { + return google::protobuf::TextFormat::ParseFromString(text, self.get()); + }) + .def("__copy__", + [](std::shared_ptr<$0> self) { + return self; + }) + .def("__deepcopy__", + [](std::shared_ptr<$0> self, py::dict) { + std::shared_ptr<$0> result = std::make_shared<$0>(); + result->CopyFrom(*self); + return result; + }) + .def("__str__", + [](std::shared_ptr<$0> self) { + return operations_research::ProtobufDebugString(*self); + }))", + current_context_.cpp_name); + } + + // Generates a pybind11 wrapper class declaration for a message. + void GenerateMessageDecl(const google::protobuf::Descriptor& msg) { + CHECK(wrapper_id_.emplace(&msg, wrapper_id_.size()).second) + << "duplicate message: " << msg.full_name(); + absl::SubstituteAndAppend(&out_, R"( + const auto $0 = py::class_<$1>($2, "$3"))", + GetWrapperName(&msg), current_context_.cpp_name, + GetWrapperName(msg.containing_type()), + msg.name()); + // Add constructor and utilities. + absl::SubstituteAndAppend(&out_, R"( + .def(py::init<>()) + .def("copy_from", + []($0* self, const $0& other) { self->CopyFrom(other); }) + .def("merge_from", + []($0* self, const $0& other) { self->MergeFrom(other); }) + .def("merge_text_format", + []($0* self, const std::string& text) { + return google::protobuf::TextFormat::MergeFromString(text, self); + }) + .def("parse_text_format", + []($0* self, const std::string& text) { + return google::protobuf::TextFormat::ParseFromString(text, self); + }) + .def("__copy__", + []($0 self) { + return $0(self); + }) + .def("__deepcopy__", + []($0 self, py::dict) { + return $0(self); + }) + .def("__str__", + []($0 self) { + return operations_research::ProtobufDebugString(self); + }))", + current_context_.cpp_name); + } + + // Generates a pybind11 wrapper class declaration for an enum. + void GenerateEnumDecl(const google::protobuf::EnumDescriptor& pb_enum) { + absl::SubstituteAndAppend(&out_, R"( + py::enum_<$0>($1, "$2"))", + GetQualifiedCppName(pb_enum), + GetWrapperName(pb_enum.containing_type()), + pb_enum.name()); + for (int i = 0; i < pb_enum.value_count(); ++i) { + const google::protobuf::EnumValueDescriptor& value = *pb_enum.value(i); + absl::SubstituteAndAppend(&out_, R"( + .value("$0", $1))", + value.name(), GetQualifiedCppName(value)); + } + absl::SubstituteAndAppend(&out_, R"( + .export_values();)"); + } + + // Generates a pybind11 wrapper class declaration & definitions for a repeated + // ptr. + void GenerateRepeatedPtrDecl(const google::protobuf::Descriptor& msg) { + absl::SubstituteAndAppend(&out_, R"( + py::class_>(py_module, "repeated_$1") + .def("add", + [](google::protobuf::RepeatedPtrField<$0>* self) { + return self->Add(); + }, + py::return_value_policy::reference, py::keep_alive<0, 1>()) + .def("append", [](google::protobuf::RepeatedPtrField<$0>* self, const $0& value) { + *self->Add() = value; + }) + .def("extend", + [](google::protobuf::RepeatedPtrField<$0>* self, const std::vector<$0>& values) { + for (const $0& value : values) { + *self->Add() = value; + } + }) + .def("__len__", &google::protobuf::RepeatedPtrField<$0>::size) + .def("__getitem__", + [](google::protobuf::RepeatedPtrField<$0>* self, int index) { + if (index >= self->size()) { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + return self->Mutable(index); + }, + py::return_value_policy::reference, py::keep_alive<0, 1>());)", + GetQualifiedCppName(msg), msg.name()); + } + + // Generates a pybind11 wrapper class declaration & definitions for a repeated + // scalar. + void GenerateRepeatedScalarDecl(absl::string_view scalar_type) { + if (scalar_type == "std::string") { + absl::StrAppend(&out_, R"( + py::class_>(py_module, "repeated_scalar_std_string") + .def("append", + [](google::protobuf::RepeatedPtrField* self, std::string str) { + self->Add(std::move(str)); + }) + .def("extend", + [](google::protobuf::RepeatedPtrField* self, + const std::vector& values) { + self->Add(values.begin(), values.end()); + }) + .def("__len__", [](const google::protobuf::RepeatedPtrField& self) { + return self.size(); + }) + .def("__getitem__", + [](const google::protobuf::RepeatedPtrField& self, int index) { + if (index >= self.size()) { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + + return self.Get(index); + }, + py::return_value_policy::copy) + .def("__setitem__", + [](google::protobuf::RepeatedPtrField* self, + int index, const std::string& value) { + self->at(index) = value; + }) + .def("__str__", [](const google::protobuf::RepeatedPtrField& self) { + return absl::StrCat("[", absl::StrJoin(self, ", "), "]"); + });)"); + } else { + absl::SubstituteAndAppend( + &out_, R"( + py::class_>(py_module, "repeated_scalar_$1") + .def("append", [](google::protobuf::RepeatedField<$0>* self, $0 value) { + self->Add(value); + }) + .def("extend", [](google::protobuf::RepeatedField<$0>* self, + const std::vector<$0>& values) { + self->Add(values.begin(), values.end()); + }) + .def("__len__", [](const google::protobuf::RepeatedField<$0>& self) { + return self.size(); + }) + .def("__getitem__", [](const google::protobuf::RepeatedField<$0>& self, int index) { + if (index >= self.size()) { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + + return self.Get(index); + }) + .def("__setitem__", &google::protobuf::RepeatedField<$0>::Set) + .def("__str__", [](const google::protobuf::RepeatedField<$0>& self) { + return absl::StrCat("[", absl::StrJoin(self, ", "), "]"); + });)", + scalar_type, absl::StrReplaceAll(scalar_type, {{"::", "_"}})); + } + } + + void GenerateRepeatedField(const google::protobuf::FieldDescriptor& field) { + const google::protobuf::Descriptor* msg_type = field.message_type(); + if (msg_type != nullptr) { + // Repeated message. + absl::SubstituteAndAppend( + &out_, R"( + .def_property_readonly( + "$0", + []($1 self) { return self->mutable_$2(); }, + py::return_value_policy::reference, py::keep_alive<0, 1>()))", + field.name(), current_context_.self_mutable_name, field.name()); + // We'll need to generate the wrapping for `proto2::RepeatedPtrField<$3>`. + repeated_ptr_types_.insert(msg_type); + // We'll need to generate the wrapping for this message type. + message_stack_.push_back(ABSL_DIE_IF_NULL(field.message_type())); + } else { + // Repeated scalar field. + absl::SubstituteAndAppend(&out_, R"( + .def_property_readonly( + "$0", + []($1 self) { return self->mutable_$0(); }, + py::return_value_policy::reference, py::keep_alive<0, 1>()))", + field.name(), + current_context_.self_mutable_name); + // We'll need to generate the wrapping for `proto2::RepeatedField<$2>`. + repeated_scalar_types_.insert(GetCppType(field.cpp_type(), field)); + } + } + + void GenerateSingularField(const google::protobuf::FieldDescriptor& field) { + if (const google::protobuf::Descriptor* msg_type = field.message_type()) { + // Singular message. + absl::SubstituteAndAppend(&out_, R"( + .def_property( + "$0", + []($1 self) { return self->mutable_$0(); }, + []($1 self, $2 arg) { *self->mutable_$0() = arg; }, + py::return_value_policy::reference_internal) + .def("clear_$0", []($1 self) { self->clear_$0(); }) + .def("has_$0", []($1 self) { return self->has_$0(); }))", + field.name(), + current_context_.self_mutable_name, + GetQualifiedCppName(*msg_type)); + // We'll need to generate the wrapping for this message type. + message_stack_.push_back(ABSL_DIE_IF_NULL(field.message_type())); + } else { + if (const google::protobuf::EnumDescriptor* enum_type = + field.enum_type()) { + enum_types_.insert(enum_type); + } + // Singular scalar (int, string, ...). + absl::SubstituteAndAppend(&out_, R"( + .def_property( + "$0", + []($1 msg) { return msg->$0(); }, + []($1 msg, $2 arg) { return msg->set_$0(arg); }) + .def("clear_$0", []($1 self) { self->clear_$0(); }))", + field.name(), + current_context_.self_mutable_name, + GetCppType(field.cpp_type(), field)); + } + } + + // Generates definitions for accessing fields of a message. + void GenerateMessageFields(const google::protobuf::Descriptor& msg) { + const std::string msg_name = GetQualifiedCppName(msg); + + for (int i = 0; i < msg.field_count(); ++i) { + const google::protobuf::FieldDescriptor& field = + *ABSL_DIE_IF_NULL(msg.field(i)); + if (field.is_repeated()) { + GenerateRepeatedField(field); + } else { + GenerateSingularField(field); + } + } + } + + // Returns the wrapper name for a message (or "py_module" if `msg` is null). + // Dies if the scope is not found. + std::string GetWrapperName(const google::protobuf::Descriptor* msg) { + const auto it = wrapper_id_.find(msg); + CHECK(it != wrapper_id_.end()) + << "wrapper id not found: " << msg->full_name(); + if (msg == nullptr) return "py_module"; + return absl::StrCat("gen_", it->second); + } + + // This identifies the pybind11 wrapper variable for a `_class` declaration in + // the generated code. These are used to generate enums in the correct + // scope. + static constexpr int kNoScope = 0; + absl::flat_hash_map wrapper_id_ = { + {nullptr, kNoScope}}; + + // Output buffer. + std::string out_; + + // Our DFS stack. + std::vector message_stack_; + absl::flat_hash_set + visited_messages_; + + // A list of enum wrappers to generate. + absl::flat_hash_set + enum_types_; + // A list of repeated ptr wrappers to generate. + absl::flat_hash_set + repeated_ptr_types_; + // A list of repeated scalar wrappers to generate. + absl::flat_hash_set repeated_scalar_types_; + + // Context for the current message being generated. + Context current_context_; +}; + +std::string GeneratePybindCode( + absl::Span roots) { + return Generator(roots).Result(); +} + +} // namespace operations_research::sat::python diff --git a/ortools/sat/python/wrappers.h b/ortools/sat/python/wrappers.h new file mode 100644 index 0000000000..04aaa6e594 --- /dev/null +++ b/ortools/sat/python/wrappers.h @@ -0,0 +1,31 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_PYTHON_WRAPPERS_H_ +#define OR_TOOLS_SAT_PYTHON_WRAPPERS_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/types/span.h" +#include "google/protobuf/descriptor.h" + +namespace operations_research::sat::python { + +// Generated pybind11 code for the given proto messages. +std::string GeneratePybindCode( + absl::Span roots); + +} // namespace operations_research::sat::python + +#endif // OR_TOOLS_SAT_PYTHON_WRAPPERS_H_ diff --git a/ortools/sat/samples/assumptions_sample_sat.go b/ortools/sat/samples/assumptions_sample_sat.go index d4564dbd48..f2603241d6 100644 --- a/ortools/sat/samples/assumptions_sample_sat.go +++ b/ortools/sat/samples/assumptions_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/boolean_product_sample_sat.go b/ortools/sat/samples/boolean_product_sample_sat.go index ad93ecde04..cb0185be42 100644 --- a/ortools/sat/samples/boolean_product_sample_sat.go +++ b/ortools/sat/samples/boolean_product_sample_sat.go @@ -19,8 +19,9 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func booleanProductSample() error { diff --git a/ortools/sat/samples/channeling_sample_sat.go b/ortools/sat/samples/channeling_sample_sat.go index 9ce0bfa0d4..a35c6f2337 100644 --- a/ortools/sat/samples/channeling_sample_sat.go +++ b/ortools/sat/samples/channeling_sample_sat.go @@ -19,9 +19,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func channelingSampleSat() error { diff --git a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go index 651f72c849..cc7ca29fc1 100644 --- a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go +++ b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go @@ -20,9 +20,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) const ( diff --git a/ortools/sat/samples/no_overlap_sample_sat.go b/ortools/sat/samples/no_overlap_sample_sat.go index 8b94881aec..e24e7775e8 100644 --- a/ortools/sat/samples/no_overlap_sample_sat.go +++ b/ortools/sat/samples/no_overlap_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/rabbits_and_pheasants_sat.go b/ortools/sat/samples/rabbits_and_pheasants_sat.go index 1a8cad267c..3dc6a51190 100644 --- a/ortools/sat/samples/rabbits_and_pheasants_sat.go +++ b/ortools/sat/samples/rabbits_and_pheasants_sat.go @@ -20,6 +20,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/ranking_sample_sat.go b/ortools/sat/samples/ranking_sample_sat.go index 779d477040..838d8c6e80 100644 --- a/ortools/sat/samples/ranking_sample_sat.go +++ b/ortools/sat/samples/ranking_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/search_for_all_solutions_sample_sat.go b/ortools/sat/samples/search_for_all_solutions_sample_sat.go index 31a1fbd98f..15ba8ec56d 100644 --- a/ortools/sat/samples/search_for_all_solutions_sample_sat.go +++ b/ortools/sat/samples/search_for_all_solutions_sample_sat.go @@ -20,8 +20,9 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func searchForAllSolutionsSampleSat() error { diff --git a/ortools/sat/samples/simple_sat_program.go b/ortools/sat/samples/simple_sat_program.go index 47c151adcf..588ceed1ea 100644 --- a/ortools/sat/samples/simple_sat_program.go +++ b/ortools/sat/samples/simple_sat_program.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/solution_hinting_sample_sat.go b/ortools/sat/samples/solution_hinting_sample_sat.go index 1b7a31f4ee..8ad6434151 100644 --- a/ortools/sat/samples/solution_hinting_sample_sat.go +++ b/ortools/sat/samples/solution_hinting_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go index 573b6b2c91..46b85cb548 100644 --- a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go +++ b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go @@ -19,8 +19,9 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func solveAndPrintIntermediateSolutionsSampleSat() error { diff --git a/ortools/sat/samples/solve_with_time_limit_sample_sat.go b/ortools/sat/samples/solve_with_time_limit_sample_sat.go index a600391017..c7b89e8d51 100644 --- a/ortools/sat/samples/solve_with_time_limit_sample_sat.go +++ b/ortools/sat/samples/solve_with_time_limit_sample_sat.go @@ -19,9 +19,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func solveWithTimeLimitSampleSat() error { diff --git a/ortools/sat/samples/step_function_sample_sat.go b/ortools/sat/samples/step_function_sample_sat.go index 21b9e1f044..7fa569d7da 100644 --- a/ortools/sat/samples/step_function_sample_sat.go +++ b/ortools/sat/samples/step_function_sample_sat.go @@ -19,9 +19,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func stepFunctionSampleSat() error { diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index ab7d851a1d..f6502cdfdb 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -24,7 +24,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 328 +// NEXT TAG: 329 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -1277,6 +1277,11 @@ message SatParameters { // SPLIT_STRATEGY_BALANCED_TREE and SPLIT_STRATEGY_DISCREPANCY. optional int32 shared_tree_balance_tolerance = 305 [default = 1]; + // How much dtime a worker will wait between proposing splits. + // This limits the contention in splitting the shared tree, and also reduces + // the number of too-easy subtrees that are generates. + optional double shared_tree_split_min_dtime = 328 [default = 0.1]; + // Whether we enumerate all solutions of a problem without objective. Note // that setting this to true automatically disable some presolve reduction // that can remove feasible solution. That is it has the same effect as diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index f09a824c4d..6b4344cc08 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -280,9 +280,9 @@ std::string ProgressMessage(absl::string_view event_or_solution_count, obj_next, solution_info); } -std::string SatProgressMessage(const std::string& event_or_solution_count, +std::string SatProgressMessage(absl::string_view event_or_solution_count, double time_in_seconds, - const std::string& solution_info) { + absl::string_view solution_info) { return absl::StrFormat("#%-5s %6.2fs %s", event_or_solution_count, time_in_seconds, solution_info); } @@ -766,7 +766,7 @@ void SharedResponseManager::FillObjectiveValuesInResponse( std::shared_ptr::Solution> SharedResponseManager::NewSolution(absl::Span solution_values, - const std::string& solution_info, + absl::string_view solution_info, Model* model, int source_id) { absl::MutexLock mutex_lock(&mutex_); std::shared_ptr::Solution> ret; @@ -849,7 +849,7 @@ SharedResponseManager::NewSolution(absl::Span solution_values, } if (logger_->LoggingIsEnabled()) { - std::string solution_message = solution_info; + std::string solution_message(solution_info); if (tmp_postsolved_response.num_booleans() > 0) { absl::StrAppend(&solution_message, " (fixed_bools=", tmp_postsolved_response.num_fixed_booleans(), "/", diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index 83085ff27e..c6babb7d5f 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -476,7 +476,7 @@ class SharedResponseManager { // stored in the repository. std::shared_ptr::Solution> NewSolution(absl::Span solution_values, - const std::string& solution_info, Model* model = nullptr, + absl::string_view solution_info, Model* model = nullptr, int source_id = -1); // Changes the solution to reflect the fact that the "improving" problem is diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index cec74c9b52..11bb2324bb 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -50,11 +50,9 @@ namespace operations_research::sat { namespace { - -// We restart the shared tree 10 times after 2 restarts per worker. After that -// we restart when the tree reaches the maximum allowable number of nodes, but -// still at most once per 2 restarts per worker. -const int kSyncsPerWorkerPerRestart = 2; +// We restart the shared tree 10 times after (on average) 2 tree assignments per +// worker. +const int kAssignmentsPerWorkerPerRestart = 2; const int kNumInitialRestarts = 10; // If you build a tree by expanding the nodes with minimal depth+discrepancy, @@ -233,7 +231,6 @@ SharedTreeManager::SharedTreeManager(Model* model) {.literal = ProtoLiteral(), .objective_lb = shared_response_manager_->GetInnerObjectiveLowerBound(), .trail_info = std::make_unique()}); - unassigned_leaves_.reserve(num_workers_); unassigned_leaves_.push_back(&nodes_.back()); } @@ -279,7 +276,10 @@ bool SharedTreeManager::SyncTree(ProtoTrail& path) { return false; } // Restart after processing updates - we might learn a new objective bound. - if (++num_syncs_since_restart_ / num_workers_ > kSyncsPerWorkerPerRestart && + // Do initial restarts once the tree has been split a reasonable number of + // times. + if (num_leaves_assigned_since_restart_ > + kAssignmentsPerWorkerPerRestart * num_workers_ && num_restarts_ < kNumInitialRestarts) { RestartLockHeld(); path.Clear(); @@ -371,11 +371,10 @@ void SharedTreeManager::ReplaceTree(ProtoTrail& path) { } path.Clear(); while (!unassigned_leaves_.empty()) { - const int i = num_leaves_assigned_++ % unassigned_leaves_.size(); - std::swap(unassigned_leaves_[i], unassigned_leaves_.back()); - Node* leaf = unassigned_leaves_.back(); - unassigned_leaves_.pop_back(); + Node* leaf = unassigned_leaves_.front(); + unassigned_leaves_.pop_front(); if (!leaf->closed && leaf->children[0] == nullptr) { + num_leaves_assigned_since_restart_ += 1; AssignLeaf(path, leaf); path.SetTargetPhase(GetTrailInfo(leaf)->phase); return; @@ -471,8 +470,7 @@ void SharedTreeManager::ProcessNodeChanges() { } if (num_newly_closed > 0) { shared_response_manager_->LogMessageWithThrottling( - "Tree", absl::StrCat("nodes:", nodes_.size(), "/", max_nodes_, - " closed:", num_closed_nodes_, + "Tree", absl::StrCat("closed:", num_closed_nodes_, "/", nodes_.size(), " unassigned:", unassigned_leaves_.size(), " restarts:", num_restarts_)); } @@ -582,7 +580,7 @@ void SharedTreeManager::RestartLockHeld() { num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1; num_closed_nodes_ = 0; num_restarts_ += 1; - num_syncs_since_restart_ = 0; + num_leaves_assigned_since_restart_ = 0; } std::string SharedTreeManager::ShortStatus() const { @@ -729,8 +727,9 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) { const auto& decision_policy = heuristics_->decision_policies[heuristics_->policy_index]; const int next_level = sat_solver_->CurrentDecisionLevel() + 1; - new_split_available_ = next_level == assigned_tree_.MaxLevel() + 1; - + if (next_level == assigned_tree_.MaxLevel() + 1) { + new_split_available_ = true; + } CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); if (next_level <= assigned_tree_.MaxLevel()) { VLOG(2) << "Following shared trail depth=" << next_level << " " @@ -747,7 +746,8 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) { void SharedTreeWorker::MaybeProposeSplit() { if (!new_split_available_ || - sat_solver_->CurrentDecisionLevel() != assigned_tree_.MaxLevel() + 1) { + sat_solver_->CurrentDecisionLevel() < assigned_tree_.MaxLevel() + 1 || + time_limit_->GetElapsedDeterministicTime() < next_split_dtime_) { return; } new_split_available_ = false; @@ -755,6 +755,8 @@ void SharedTreeWorker::MaybeProposeSplit() { sat_solver_->Decisions()[assigned_tree_.MaxLevel()].literal; const std::optional encoded = EncodeDecision(split_decision); if (encoded.has_value()) { + next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() + + parameters_->shared_tree_split_min_dtime(); CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); manager_->ProposeSplit(assigned_tree_, *encoded); if (assigned_tree_.MaxLevel() > assigned_tree_literals_.size()) { diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index 980db80226..d2f7463f8b 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -281,7 +281,7 @@ class SharedTreeManager { // Stores the nodes in the search tree. std::deque nodes_ ABSL_GUARDED_BY(mu_); - std::vector unassigned_leaves_ ABSL_GUARDED_BY(mu_); + std::deque unassigned_leaves_ ABSL_GUARDED_BY(mu_); // How many splits we should generate now to keep the desired number of // leaves. @@ -291,7 +291,7 @@ class SharedTreeManager { // communication overhead. If we exceed this, workers become portfolio // workers when no unassigned leaves are available. const int max_nodes_; - int num_leaves_assigned_ ABSL_GUARDED_BY(mu_) = 0; + int num_leaves_assigned_since_restart_ ABSL_GUARDED_BY(mu_) = 0; // Temporary vectors used to maintain the state of the tree when nodes are // closed and/or children are updated. @@ -299,7 +299,6 @@ class SharedTreeManager { std::vector to_update_ ABSL_GUARDED_BY(mu_); int64_t num_restarts_ ABSL_GUARDED_BY(mu_) = 0; - int64_t num_syncs_since_restart_ ABSL_GUARDED_BY(mu_) = 0; int num_closed_nodes_ ABSL_GUARDED_BY(mu_) = 0; }; @@ -359,6 +358,7 @@ class SharedTreeWorker { ProtoTrail assigned_tree_; std::vector assigned_tree_literals_; std::vector> assigned_tree_implications_; + double next_split_dtime_ = 0; // True if the last decision may split the assigned tree and has not yet been // proposed to the SharedTreeManager. diff --git a/ortools/sat/work_assignment_test.cc b/ortools/sat/work_assignment_test.cc index 29cc768b4d..90a8c287d4 100644 --- a/ortools/sat/work_assignment_test.cc +++ b/ortools/sat/work_assignment_test.cc @@ -576,6 +576,9 @@ TEST(SharedTreeManagerTest, TrailSharing) { shared_tree_manager->ReplaceTree(trail1); shared_tree_manager->ReplaceTree(trail2); + EXPECT_EQ(shared_tree_manager->NumNodes(), 3); + EXPECT_EQ(trail1.MaxLevel(), 1); + EXPECT_EQ(trail2.MaxLevel(), 1); EXPECT_EQ(trail2.Implications(1).size(), 1); EXPECT_EQ(trail2.TargetPhase().size(), 1); EXPECT_TRUE(trail1.Implications(1).empty());