diff --git a/examples/python/bus_driver_scheduling_sat.py b/examples/python/bus_driver_scheduling_sat.py index 2032c223d5..6968f1d558 100644 --- a/examples/python/bus_driver_scheduling_sat.py +++ b/examples/python/bus_driver_scheduling_sat.py @@ -164,6 +164,10 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): delay_literals = [] delay_weights = [] + # Used to propagate more between drivers + shared_incoming_literals = collections.defaultdict(list) + shared_outgoing_literals = collections.defaultdict(list) + for d in range(num_drivers): start_times.append( model.NewIntVar(min_start_time - setup_time, max_end_time, @@ -200,6 +204,7 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): source_lit = model.NewBoolVar('%i from source to %i' % (d, s)) outgoing_source_literals.append(source_lit) incoming_literals[s].append(source_lit) + shared_incoming_literals[s].append(source_lit) model.Add(start_times[d] == shift[3] - setup_time).OnlyEnforceIf(source_lit) model.Add( @@ -213,6 +218,7 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): # - set the driving times of the driver sink_lit = model.NewBoolVar('%i from %i to sink' % (d, s)) outgoing_literals[s].append(sink_lit) + shared_outgoing_literals[s].append(sink_lit) incoming_sink_literals.append(sink_lit) model.Add(end_times[d] == shift[4] + cleanup_time).OnlyEnforceIf(sink_lit) @@ -228,6 +234,8 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): performed[d, s].Not()) incoming_literals[s].append(performed[d, s].Not()) outgoing_literals[s].append(performed[d, s].Not()) + # Not adding to the shared lists, because, globally, each node will have + # one incoming literal, and one outgoing literal. # Node performed: # - add upper bound on start_time @@ -259,7 +267,9 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): # Add arc outgoing_literals[s].append(lit) + shared_outgoing_literals[s].append(lit) incoming_literals[o].append(lit) + shared_incoming_literals[o].append(lit) # Cost part delay_literals.append(lit) @@ -296,6 +306,9 @@ def bus_driver_scheduling(minimize_drivers, max_num_drivers): # Each shift is covered. for s in range(num_shifts): model.Add(sum(performed[d, s] for d in range(num_drivers)) == 1) + # Globally, each node has one incoming and one outgoing literal + model.Add(sum(shared_incoming_literals[s]) == 1) + model.Add(sum(shared_outgoing_literals[s]) == 1) # Symmetry breaking