[CP-SAT] support numpy integers in most of the API

This commit is contained in:
Laurent Perron
2021-11-15 11:48:17 +01:00
parent 2493b4f028
commit 9b75a7118d
2 changed files with 113 additions and 99 deletions

View File

@@ -45,14 +45,13 @@ rather than for solving specific optimization problems.
"""
import collections
import numbers
import threading
import time
import warnings
from ortools.sat import cp_model_pb2
from ortools.sat import sat_parameters_pb2
from ortools.sat.python import cp_model_helper
from ortools.sat.python import cp_model_helper as cmh
from ortools.sat import pywrapsat
from ortools.util import sorted_interval_list
@@ -269,17 +268,17 @@ class LinearExpr(object):
'please use CpModel.AddAbsEquality')
def __add__(self, arg):
if isinstance(arg, numbers.Integral) and arg == 0:
if cmh.IsIntegral(arg) and int(arg) == 0:
return self
return _SumArray([self, arg])
def __radd__(self, arg):
if isinstance(arg, numbers.Integral) and arg == 0:
if cmh.IsIntegral(arg) and int(arg) == 0:
return self
return _SumArray([self, arg])
def __sub__(self, arg):
if isinstance(arg, numbers.Integral) and arg == 0:
if cmh.IsIntegral(arg) and int(arg) == 0:
return self
return _SumArray([self, -arg])
@@ -287,20 +286,19 @@ class LinearExpr(object):
return _SumArray([-self, arg])
def __mul__(self, arg):
if isinstance(arg, numbers.Integral):
if arg == 1:
return self
elif arg == 0:
return 0
cp_model_helper.AssertIsInt64(arg)
return _ProductCst(self, arg)
else:
raise TypeError('Not an integer linear expression: ' + str(arg))
def __rmul__(self, arg):
cp_model_helper.AssertIsInt64(arg)
arg = cmh.AssertIsInt64(arg)
if arg == 1:
return self
elif arg == 0:
return 0
return _ProductCst(self, arg)
def __rmul__(self, arg):
arg = cmh.AssertIsInt64(arg)
if arg == 1:
return self
elif arg == 0:
return 0
return _ProductCst(self, arg)
def __div__(self, _):
@@ -356,29 +354,29 @@ class LinearExpr(object):
def __eq__(self, arg):
if arg is None:
return False
if isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
if cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
return BoundedLinearExpression(self, [arg, arg])
else:
return BoundedLinearExpression(self - arg, [0, 0])
def __ge__(self, arg):
if isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
if cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
return BoundedLinearExpression(self, [arg, INT_MAX])
else:
return BoundedLinearExpression(self - arg, [0, INT_MAX])
def __le__(self, arg):
if isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
if cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
return BoundedLinearExpression(self, [INT_MIN, arg])
else:
return BoundedLinearExpression(self - arg, [INT_MIN, 0])
def __lt__(self, arg):
if isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
if cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
if arg == INT_MIN:
raise ArithmeticError('< INT_MIN is not supported')
return BoundedLinearExpression(self, [INT_MIN, arg - 1])
@@ -386,8 +384,8 @@ class LinearExpr(object):
return BoundedLinearExpression(self - arg, [INT_MIN, -1])
def __gt__(self, arg):
if isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
if cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
if arg == INT_MAX:
raise ArithmeticError('> INT_MAX is not supported')
return BoundedLinearExpression(self, [arg + 1, INT_MAX])
@@ -397,8 +395,8 @@ class LinearExpr(object):
def __ne__(self, arg):
if arg is None:
return True
if isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
if cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
if arg == INT_MAX:
return BoundedLinearExpression(self, [INT_MIN, INT_MAX - 1])
elif arg == INT_MIN:
@@ -415,7 +413,7 @@ class _ProductCst(LinearExpr):
"""Represents the product of a LinearExpr by a constant."""
def __init__(self, expr, coef):
cp_model_helper.AssertIsInt64(coef)
coef = cmh.AssertIsInt64(coef)
if isinstance(expr, _ProductCst):
self.__expr = expr.Expression()
self.__coef = expr.Coefficient() * coef
@@ -447,8 +445,8 @@ class _SumArray(LinearExpr):
self.__expressions = []
self.__constant = constant
for x in expressions:
if isinstance(x, numbers.Integral):
cp_model_helper.AssertIsInt64(x)
if cmh.IsIntegral(x):
x = cmh.AssertIsInt64(x)
self.__constant += x
elif isinstance(x, LinearExpr):
self.__expressions.append(x)
@@ -485,11 +483,11 @@ class _ScalProd(LinearExpr):
'In the LinearExpr.ScalProd method, the expression array and the '
' coefficient array must have the same length.')
for e, c in zip(expressions, coefficients):
cp_model_helper.AssertIsInt64(c)
c = cmh.AssertIsInt64(c)
if c == 0:
continue
if isinstance(e, numbers.Integral):
cp_model_helper.AssertIsInt64(e)
if cmh.IsIntegral(e):
e = cmh.AssertIsInt64(e)
self.__constant += e * c
elif isinstance(e, LinearExpr):
self.__expressions.append(e)
@@ -562,8 +560,8 @@ class IntVar(LinearExpr):
# model is a CpModelProto, domain is a Domain, and name is a string.
# case 2:
# model is a CpModelProto, domain is an index (int), and name is None.
if isinstance(domain, numbers.Integral) and name is None:
self.__index = domain
if cmh.IsIntegral(domain) and name is None:
self.__index = int(domain)
self.__var = model.variables[domain]
else:
self.__index = len(model.variables)
@@ -740,12 +738,12 @@ class Constraint(object):
self.
"""
if isinstance(boolvar, numbers.Integral) and boolvar == 1:
if cmh.IsIntegral(boolvar) and int(boolvar) == 1:
# Always true. Do nothing.
pass
elif isinstance(boolvar, list):
for b in boolvar:
if isinstance(b, numbers.Integral) and b == 1:
if cmh.IsIntegral(b) and int(b) == 1:
pass
else:
self.__constraint.enforcement_literal.append(b.Index())
@@ -855,8 +853,8 @@ def ObjectIsATrueLiteral(literal):
proto = literal.Not().Proto()
return (len(proto.domain) == 2 and proto.domain[0] == 0 and
proto.domain[1] == 0)
if isinstance(literal, numbers.Integral):
return literal == 1
if cmh.IsIntegral(literal):
return int(literal) == 1
return False
@@ -870,8 +868,8 @@ def ObjectIsAFalseLiteral(literal):
proto = literal.Not().Proto()
return (len(proto.domain) == 2 and proto.domain[0] == 1 and
proto.domain[1] == 1)
if isinstance(literal, numbers.Integral):
return literal == 0
if cmh.IsIntegral(literal):
return int(literal) == 0
return False
@@ -948,16 +946,14 @@ class CpModel(object):
for t in coeffs_map.items():
if not isinstance(t[0], IntVar):
raise TypeError('Wrong argument' + str(t))
cp_model_helper.AssertIsInt64(t[1])
c = cmh.AssertIsInt64(t[1])
model_ct.linear.vars.append(t[0].Index())
model_ct.linear.coeffs.append(t[1])
model_ct.linear.domain.extend([
cp_model_helper.CapSub(x, constant)
for x in domain.FlattenedIntervals()
])
model_ct.linear.coeffs.append(c)
model_ct.linear.domain.extend(
[cmh.CapSub(x, constant) for x in domain.FlattenedIntervals()])
return ct
elif isinstance(linear_expr, numbers.Integral):
if not domain.Contains(linear_expr):
elif cmh.IsIntegral(linear_expr):
if not domain.Contains(int(linear_expr)):
return self.AddBoolOr([]) # Evaluate to false.
# Nothing to do otherwise.
else:
@@ -1009,8 +1005,8 @@ class CpModel(object):
if not variables:
raise ValueError('AddElement expects a non-empty variables array')
if isinstance(index, numbers.Integral):
return self.Add(list(variables)[index] == target)
if cmh.IsIntegral(index):
return self.Add(list(variables)[int(index)] == target)
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
@@ -1047,11 +1043,11 @@ class CpModel(object):
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
for arc in arcs:
cp_model_helper.AssertIsInt32(arc[0])
cp_model_helper.AssertIsInt32(arc[1])
tail = cmh.AssertIsInt32(arc[0])
head = cmh.AssertIsInt32(arc[1])
lit = self.GetOrMakeBooleanIndex(arc[2])
model_ct.circuit.tails.append(arc[0])
model_ct.circuit.heads.append(arc[1])
model_ct.circuit.tails.append(tail)
model_ct.circuit.heads.append(head)
model_ct.circuit.literals.append(lit)
return ct
@@ -1089,9 +1085,10 @@ class CpModel(object):
for t in tuples_list:
if len(t) != arity:
raise TypeError('Tuple ' + str(t) + ' has the wrong arity')
ar = []
for v in t:
cp_model_helper.AssertIsInt64(v)
model_ct.table.values.extend(t)
ar.append(cmh.AssertIsInt64(v))
model_ct.table.values.extend(ar)
return ct
def AddForbiddenAssignments(self, variables, tuples_list):
@@ -1181,21 +1178,21 @@ class CpModel(object):
model_ct = self.__model.constraints[ct.Index()]
model_ct.automaton.vars.extend(
[self.GetOrMakeIndex(x) for x in transition_variables])
cp_model_helper.AssertIsInt64(starting_state)
starting_state = cmh.AssertIsInt64(starting_state)
model_ct.automaton.starting_state = starting_state
for v in final_states:
cp_model_helper.AssertIsInt64(v)
v = cmh.AssertIsInt64(v)
model_ct.automaton.final_states.append(v)
for t in transition_triples:
if len(t) != 3:
raise TypeError('Tuple ' + str(t) +
' has the wrong arity (!= 3)')
cp_model_helper.AssertIsInt64(t[0])
cp_model_helper.AssertIsInt64(t[1])
cp_model_helper.AssertIsInt64(t[2])
model_ct.automaton.transition_tail.append(t[0])
model_ct.automaton.transition_label.append(t[1])
model_ct.automaton.transition_head.append(t[2])
tail = cmh.AssertIsInt64(t[0])
label = cmh.AssertIsInt64(t[1])
head = cmh.AssertIsInt64(t[2])
model_ct.automaton.transition_tail.append(tail)
model_ct.automaton.transition_label.append(label)
model_ct.automaton.transition_head.append(head)
return ct
def AddInverse(self, variables, inverse_variables):
@@ -1435,9 +1432,10 @@ class CpModel(object):
"""Adds `target == num // denom` (integer division rounded towards 0)."""
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
model_ct.int_div.exprs.append(self.ParseLinearExpression(num))
model_ct.int_div.exprs.append(self.ParseLinearExpression(denom))
model_ct.int_div.target.CopyFrom(self.ParseLinearExpression(target))
model_ct.int_div.vars.extend(
[self.GetOrMakeIndex(num),
self.GetOrMakeIndex(denom)])
model_ct.int_div.target = self.GetOrMakeIndex(target)
return ct
def AddAbsEquality(self, target, expr):
@@ -1453,20 +1451,28 @@ class CpModel(object):
"""Adds `target = var % mod`."""
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
model_ct.int_mod.exprs.append(self.ParseLinearExpression(var))
model_ct.int_mod.exprs.append(self.ParseLinearExpression(mod))
model_ct.int_mod.target.CopyFrom(self.ParseLinearExpression(target))
model_ct.int_mod.vars.extend(
[self.GetOrMakeIndex(var),
self.GetOrMakeIndex(mod)])
model_ct.int_mod.target = self.GetOrMakeIndex(target)
return ct
def AddMultiplicationEquality(self, target, expressions):
def AddMultiplicationEquality(self, target, variables):
"""Adds `target == variables[0] * .. * variables[n]`."""
ct = Constraint(self.__model.constraints)
model_ct = self.__model.constraints[ct.Index()]
model_ct.int_prod.exprs.extend(
[self.ParseLinearExpression(expr) for expr in expressions])
model_ct.int_prod.target.CopyFrom(self.ParseLinearExpression(target))
model_ct.int_prod.vars.extend(
[self.GetOrMakeIndex(x) for x in variables])
model_ct.int_prod.target = self.GetOrMakeIndex(target)
return ct
def AddProdEquality(self, target, variables):
"""Deprecated, use AddMultiplicationEquality."""
warnings.warn(
'AddProdEquality is deprecated; use' + 'AddMultiplicationEquality.',
DeprecationWarning)
return self.AddMultiplicationEquality(target, variables)
# Scheduling support
def NewIntervalVar(self, start, size, end, name):
@@ -1521,7 +1527,7 @@ class CpModel(object):
Returns:
An `IntervalVar` object.
"""
cp_model_helper.AssertIsInt64(size)
size = cmh.AssertIsInt64(size)
start_expr = self.ParseLinearExpression(start)
size_expr = self.ParseLinearExpression(size)
end_expr = self.ParseLinearExpression(start + size)
@@ -1592,7 +1598,7 @@ class CpModel(object):
Returns:
An `IntervalVar` object.
"""
cp_model_helper.AssertIsInt64(size)
size = cmh.AssertIsInt64(size)
start_expr = self.ParseLinearExpression(start)
size_expr = self.ParseLinearExpression(size)
end_expr = self.ParseLinearExpression(start + size)
@@ -1736,8 +1742,8 @@ class CpModel(object):
elif (isinstance(arg, _ProductCst) and
isinstance(arg.Expression(), IntVar) and arg.Coefficient() == -1):
return -arg.Expression().Index() - 1
elif isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsInt64(arg)
elif cmh.IsIntegral(arg):
arg = cmh.AssertIsInt64(arg)
return self.GetOrMakeIndexFromConstant(arg)
else:
raise TypeError('NotSupported: model.GetOrMakeIndex(' + str(arg) +
@@ -1751,9 +1757,9 @@ class CpModel(object):
elif isinstance(arg, _NotBooleanVariable):
self.AssertIsBooleanVariable(arg.Not())
return arg.Index()
elif isinstance(arg, numbers.Integral):
cp_model_helper.AssertIsBoolean(arg)
return self.GetOrMakeIndexFromConstant(arg)
elif cmh.IsIntegral(arg):
cmh.AssertIsBoolean(arg)
return self.GetOrMakeIndexFromConstant(int(arg))
else:
raise TypeError('NotSupported: model.GetOrMakeBooleanIndex(' +
str(arg) + ')')
@@ -1782,8 +1788,8 @@ class CpModel(object):
"""Returns a LinearExpressionProto built from a LinearExpr instance."""
result = cp_model_pb2.LinearExpressionProto()
mult = -1 if negate else 1
if isinstance(linear_expr, numbers.Integral):
result.offset = linear_expr * mult
if cmh.IsIntegral(linear_expr):
result.offset = int(linear_expr) * mult
return result
if isinstance(linear_expr, IntVar):
@@ -1796,9 +1802,9 @@ class CpModel(object):
for t in coeffs_map.items():
if not isinstance(t[0], IntVar):
raise TypeError('Wrong argument' + str(t))
cp_model_helper.AssertIsInt64(t[1])
c = cmh.AssertIsInt64(t[1])
result.vars.append(t[0].Index())
result.coeffs.append(t[1] * mult)
result.coeffs.append(c * mult)
return result
def _SetObjective(self, obj, minimize):
@@ -1828,8 +1834,8 @@ class CpModel(object):
self.__model.objective.vars.append(v.Index())
else:
self.__model.objective.vars.append(self.Negated(v.Index()))
elif isinstance(obj, numbers.Integral):
self.__model.objective.offset = obj
elif cmh.IsIntegral(obj):
self.__model.objective.offset = int(obj)
self.__model.objective.scaling_factor = 1
else:
raise TypeError('TypeError: ' + str(obj) +
@@ -1934,8 +1940,8 @@ class CpModel(object):
def EvaluateLinearExpr(expression, solution):
"""Evaluate a linear expression against a solution."""
if isinstance(expression, numbers.Integral):
return expression
if cmh.IsIntegral(expression):
return int(expression)
if not isinstance(expression, LinearExpr):
raise TypeError('Cannot interpret %s as a linear expression.' %
expression)
@@ -1963,7 +1969,7 @@ def EvaluateLinearExpr(expression, solution):
def EvaluateBooleanExpression(literal, solution):
"""Evaluate a boolean expression against a solution."""
if isinstance(literal, numbers.Integral):
if cmh.IsIntegral(literal):
return bool(literal)
elif isinstance(literal, IntVar) or isinstance(literal,
_NotBooleanVariable):
@@ -2172,7 +2178,7 @@ class CpSolverSolutionCallback(pywrapsat.SolutionCallback):
"""
if not self.HasResponse():
raise RuntimeError('Solve() has not be called.')
if isinstance(lit, numbers.Integral):
if cmh.IsIntegral(lit):
return bool(lit)
elif isinstance(lit, IntVar) or isinstance(lit, _NotBooleanVariable):
index = lit.Index()
@@ -2196,8 +2202,8 @@ class CpSolverSolutionCallback(pywrapsat.SolutionCallback):
"""
if not self.HasResponse():
raise RuntimeError('Solve() has not be called.')
if isinstance(expression, numbers.Integral):
return expression
if cmh.IsIntegral(expression):
return int(expression)
if not isinstance(expression, LinearExpr):
raise TypeError('Cannot interpret %s as a linear expression.' %
expression)

View File

@@ -13,6 +13,7 @@
"""helpers methods for the cp_model module."""
import numbers
import numpy as np
INT_MIN = -9223372036854775808 # hardcoded to be platform independent.
INT_MAX = 9223372036854775807
@@ -20,25 +21,32 @@ INT32_MIN = -2147483648
INT32_MAX = 2147483647
def IsIntegral(x):
"""Checks if x has either a number.Integral or a np.integer type."""
return isinstance(x, numbers.Integral) or isinstance(x, np.integer)
def AssertIsInt64(x):
"""Asserts that x is integer and x is in [min_int_64, max_int_64]."""
if not isinstance(x, numbers.Integral):
if not IsIntegral(x):
raise TypeError('Not an integer: %s' % x)
if x < INT_MIN or x > INT_MAX:
raise OverflowError('Does not fit in an int64_t: %s' % x)
return int(x)
def AssertIsInt32(x):
"""Asserts that x is integer and x is in [min_int_32, max_int_32]."""
if not isinstance(x, numbers.Integral):
if not IsIntegral(x):
raise TypeError('Not an integer: %s' % x)
if x < INT32_MIN or x > INT32_MAX:
raise OverflowError('Does not fit in an int32_t: %s' % x)
return int(x)
def AssertIsBoolean(x):
"""Asserts that x is 0 or 1."""
if not isinstance(x, numbers.Integral) or x < 0 or x > 1:
if not IsIntegral(x) or x < 0 or x > 1:
raise TypeError('Not an boolean: %s' % x)