Files
ortools-clone/examples/notebook/balance_group_sat.ipynb
2018-08-02 12:08:05 -07:00

90 lines
3.0 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Copyright 2010-2017 Google\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"from __future__ import print_function\n",
"\n",
"from ortools.sat.python import cp_model\n",
"\n",
"num_groups = 11\n",
"num_values = 121\n",
"\n",
"# Every group receives the same number of values, so the values must split\n",
"# evenly across the groups; otherwise the cardinality constraints below are\n",
"# infeasible.\n",
"if num_values % num_groups:\n",
"    raise ValueError('num_values (%i) must be a multiple of num_groups (%i)'\n",
"                     % (num_values, num_groups))\n",
"values_per_group = num_values // num_groups\n",
"\n",
"model = cp_model.CpModel()\n",
"\n",
"# boo_x_i_j[(i, j)] is true iff value i is assigned to group j.\n",
"boo_x_i_j = {}\n",
"for i in range(num_values):\n",
"    for j in range(num_groups):\n",
"        boo_x_i_j[(i, j)] = model.NewBoolVar('x%d belongs to group %d' % (i, j))\n",
"\n",
"# epsilon bounds the absolute deviation of every group sum from the average.\n",
"# Its domain is [0, 5]: if no assignment keeps every group within 5 of the\n",
"# average, the model is infeasible.\n",
"e = model.NewIntVar(0, 5, 'epsilon')\n",
"\n",
"# Values are 1..100 followed by 104..124 (indices above 99 are shifted by 3).\n",
"values = [i + 1 + 3 * (i > 99) for i in range(num_values)]\n",
"sum_of_values = sum(values)\n",
"# CP-SAT constraints take integer constants only. Under Python 3, true\n",
"# division would make this a float (7444 / 11 here), so use floor division,\n",
"# which also matches the Python 2 result.\n",
"average_value = sum_of_values // num_groups\n",
"\n",
"# Every group receives exactly values_per_group values.\n",
"for j in range(num_groups):\n",
"    model.Add(sum(boo_x_i_j[(i, j)] for i in range(num_values)) ==\n",
"              values_per_group)\n",
"\n",
"# Every value is assigned to exactly one group.\n",
"for i in range(num_values):\n",
"    model.Add(sum(boo_x_i_j[(i, j)] for j in range(num_groups)) == 1)\n",
"\n",
"# -e <= group_sum - average_value <= e for every group.\n",
"for j in range(num_groups):\n",
"    group_sum = sum(boo_x_i_j[(i, j)] * values[i] for i in range(num_values))\n",
"    model.Add(group_sum - average_value <= e)\n",
"    model.Add(group_sum - average_value >= -e)\n",
"\n",
"model.Minimize(e)\n",
"\n",
"\n",
"solver = cp_model.CpSolver()\n",
"status = solver.Solve(model)\n",
"# Variable values are only available when the solver found an assignment,\n",
"# so stop here with a clear error otherwise.\n",
"if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):\n",
"    raise RuntimeError('No solution found (solver status: %s)' % status)\n",
"if status == cp_model.OPTIMAL:\n",
"    print('Optimal epsilon: %i' % solver.ObjectiveValue())\n",
"else:\n",
"    print('Best epsilon found: %i' % solver.ObjectiveValue())\n",
"print('Statistics')\n",
"print(' - conflicts : %i' % solver.NumConflicts())\n",
"print(' - branches : %i' % solver.NumBranches())\n",
"print(' - wall time : %f s' % solver.WallTime())\n",
"\n",
"# Collect the values assigned to each group.\n",
"groups = {j: [] for j in range(num_groups)}\n",
"for i in range(num_values):\n",
"    for j in range(num_groups):\n",
"        if solver.Value(boo_x_i_j[(i, j)]):\n",
"            groups[j].append(values[i])\n",
"\n",
"# Print each group's average followed by its members.\n",
"for j in range(num_groups):\n",
"    members = groups[j]\n",
"    # Guard the division in case a group is empty.\n",
"    average = 1.0 * sum(members) / len(members) if members else 0.0\n",
"    print('group %i: average = %0.2f [' % (j, average), end='')\n",
"    for v in members:\n",
"        print(' %i' % v, end='')\n",
"    print(' ]')"
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}