Files
ortools-clone/examples/python/balance_group_sat.py

186 lines
6.0 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
2025-01-10 11:35:44 +01:00
# Copyright 2010-2025 Google LLC
2018-07-23 14:59:25 -07:00
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2023-07-01 06:06:53 +02:00
2018-09-19 13:21:13 +02:00
"""We are trying to group items into equal-sized groups.
Each item has a color and a value. We want the sum of values of each group to
be as close to the average as possible.
Furthermore, if a color is present in a group, at least k items with this color
must be in that group.
"""
2025-07-23 17:38:49 +02:00
2024-07-24 14:54:57 -07:00
from typing import Dict, Sequence
from absl import app
2024-07-24 14:54:57 -07:00
2018-08-07 15:16:08 -07:00
from ortools.sat.python import cp_model
2018-07-23 15:29:39 -07:00
2018-07-23 14:59:25 -07:00
# Create a solution printer.
class SolutionPrinter(cp_model.CpSolverSolutionCallback):
    """Reports each solution the CP-SAT solver finds while searching.

    For every solution it prints a running solution index, the current
    objective value, and then one line per group listing the group's value sum
    and the (item, value, color) triple of every item assigned to it.
    """

    def __init__(self, values, colors, all_groups, all_items, item_in_group):
        super().__init__()
        # Number of solutions reported so far (used as a 0-based index).
        self.__solution_count = 0
        # Problem data, indexed by item.
        self.__values = values
        self.__colors = colors
        # Iterables over group and item indices.
        self.__all_groups = all_groups
        self.__all_items = all_items
        # Maps (item, group) -> Boolean model variable.
        self.__item_in_group = item_in_group

    def on_solution_callback(self):
        """Print the current solution, grouped by group index."""
        print(f"Solution {self.__solution_count}")
        self.__solution_count += 1
        print(f" objective value = {self.objective_value}")

        members = {}
        totals = {}
        for group in self.__all_groups:
            assigned = [
                item
                for item in self.__all_items
                if self.boolean_value(self.__item_in_group[(item, group)])
            ]
            members[group] = assigned
            totals[group] = sum(self.__values[item] for item in assigned)

        for group in self.__all_groups:
            details = "".join(
                f" ({item}, {self.__values[item]}, {self.__colors[item]})"
                for item in members[group]
            )
            print(f"group {group}: sum = {totals[group]:0.2f} [" + details + "]")
2018-08-07 15:16:08 -07:00
def main(argv: Sequence[str]) -> None:
    """Solves a group balancing problem.

    Builds a CP-SAT model that assigns each item to exactly one of the
    equally sized groups such that:
      * the sum of item values in every group stays within +/- epsilon of the
        (integer) average group sum, with epsilon minimized;
      * if a color is present in a group, that group holds at least
        `min_items_of_same_color_per_group` items of that color.
    Intermediate solutions are printed by `SolutionPrinter`; the final result
    and solver statistics are printed once the search ends.

    Args:
      argv: Command-line arguments; only the program name is accepted.

    Raises:
      app.UsageError: If extra command-line arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Data.
    num_groups = 10
    num_items = 100
    num_colors = 3
    min_items_of_same_color_per_group = 4

    all_groups = range(num_groups)
    all_items = range(num_items)
    all_colors = range(num_colors)

    # Value of each item.
    values = [1 + i + (i * i // 200) for i in all_items]
    # Color of each item (simple modulo).
    colors = [i % num_colors for i in all_items]

    sum_of_values = sum(values)
    average_sum_per_group = sum_of_values // num_groups

    # NOTE: assumes num_items is a multiple of num_groups; otherwise the
    # equal-size constraint below makes the model infeasible.
    num_items_per_group = num_items // num_groups

    # Collect all items of a given color.
    items_per_color: Dict[int, list[int]] = {
        color: [i for i in all_items if colors[i] == color] for color in all_colors
    }

    print(
        f"Model has {num_items} items, {num_groups} groups, and {num_colors} colors"
    )
    print(f" average sum per group = {average_sum_per_group}")

    # Model.
    model = cp_model.CpModel()

    # item_in_group[(i, g)] is true iff item i is assigned to group g.
    item_in_group = {}
    for i in all_items:
        for g in all_groups:
            item_in_group[(i, g)] = model.new_bool_var(f"item {i} in group {g}")

    # Each group must have the same size.
    for g in all_groups:
        model.add(sum(item_in_group[(i, g)] for i in all_items) == num_items_per_group)

    # Each item must belong to exactly one group.
    for i in all_items:
        model.add(sum(item_in_group[(i, g)] for g in all_groups) == 1)

    # Maximum deviation of any group's value sum from the average. Every group
    # sum lies in [0, sum_of_values], so sum_of_values is a safe upper bound;
    # deriving it from the data (instead of a hard-coded constant) keeps the
    # model feasible if the data above is changed.
    e = model.new_int_var(0, sum_of_values, "epsilon")

    # Constrain the sum of values in each group around the average sum per group.
    for g in all_groups:
        group_sum = sum(item_in_group[(i, g)] * values[i] for i in all_items)
        model.add(group_sum <= average_sum_per_group + e)
        model.add(group_sum >= average_sum_per_group - e)

    # color_in_group[(c, g)] is forced true when some item of color c is in g.
    color_in_group = {}
    for g in all_groups:
        for c in all_colors:
            color_in_group[(c, g)] = model.new_bool_var(f"color {c} is in group {g}")

    # Item is in a group implies its color is in that group.
    for i in all_items:
        for g in all_groups:
            model.add_implication(item_in_group[(i, g)], color_in_group[(colors[i], g)])

    # If a color is in a group, that group must contain at least
    # min_items_of_same_color_per_group items of that color.
    for c in all_colors:
        for g in all_groups:
            model.add(
                sum(item_in_group[(i, g)] for i in items_per_color[c])
                >= min_items_of_same_color_per_group
            ).only_enforce_if(color_in_group[(c, g)])

    # Each present color uses at least min_items_of_same_color_per_group of the
    # group's slots, which bounds the number of distinct colors per group.
    max_color = num_items_per_group // min_items_of_same_color_per_group

    # Redundant constraint, it helps with solving time.
    if max_color < num_colors:
        for g in all_groups:
            model.add(sum(color_in_group[(c, g)] for c in all_colors) <= max_color)

    # Minimize epsilon.
    model.minimize(e)

    solver = cp_model.CpSolver()
    # Uncomment to print the solver's search log.
    # solver.parameters.log_search_progress = True
    solver.parameters.num_workers = 16
    solution_printer = SolutionPrinter(
        values, colors, all_groups, all_items, item_in_group
    )
    status = solver.solve(model, solution_printer)

    if status == cp_model.OPTIMAL:
        print(f"Optimal epsilon: {solver.objective_value}")
        print(solver.response_stats())
    elif status == cp_model.FEASIBLE:
        # A solution exists but optimality was not proven (e.g. the search was
        # interrupted); report it instead of claiming no solution was found.
        print(f"Best epsilon found (not proven optimal): {solver.objective_value}")
        print(solver.response_stats())
    else:
        print("No solution found")
2023-07-01 06:06:53 +02:00
if __name__ == "__main__":
    # Run main through absl's app runner, which supplies main's argv.
    app.run(main)