use python3.7+ typing in model_builder and cp_model

This commit is contained in:
Laurent Perron
2023-07-21 16:42:55 -07:00
parent b98044950f
commit 9a5c7e8dad
2 changed files with 9 additions and 9 deletions

View File

@@ -37,7 +37,7 @@ import dataclasses
import math
import numbers
import typing
from typing import Callable, Mapping, Optional, Sequence, Union
from typing import Callable, Mapping, Optional, Sequence, Union, cast
import numpy as np
from numpy import typing as npt
@@ -1017,8 +1017,8 @@ class ModelBuilder:
if (
isinstance(is_integral, bool)
and is_integral
and isinstance(lower_bounds, NumberT)
and isinstance(upper_bounds, NumberT)
and mbn.is_a_number(lower_bounds)
and mbn.is_a_number(upper_bounds)
and math.isfinite(lower_bounds)
and math.isfinite(upper_bounds)
and math.ceil(lower_bounds) > math.floor(upper_bounds)
@@ -1583,8 +1583,8 @@ def _as_flat_linear_expression(base_expr: LinearExprT) -> _LinearExpression:
to_process.append((expr._right, coeff))
elif isinstance(expr, Variable):
terms[expr] += coeff
elif isinstance(expr, NumberT):
offset += coeff * expr
elif mbn.is_a_number(expr):
offset += coeff * cast(NumberT, expr)
elif isinstance(expr, _Product):
to_process.append((expr._expression, coeff * expr._coefficient))
elif isinstance(expr, _LinearExpression):
@@ -1680,7 +1680,7 @@ def _convert_to_series_and_validate_index(
TypeError: If the type of `value_or_series` is not recognized.
ValueError: If the index does not match.
"""
if isinstance(value_or_series, (bool, NumberT)):
if mbn.is_a_number(value_or_series) or isinstance(value_or_series, bool):
result = pd.Series(data=value_or_series, index=index)
elif isinstance(value_or_series, pd.Series):
if value_or_series.index.equals(index):

View File

@@ -1188,8 +1188,8 @@ class CpModel:
if not name.isidentifier():
raise ValueError("name={} is not a valid identifier".format(name))
if (
isinstance(lower_bounds, IntegralT)
and isinstance(upper_bounds, IntegralT)
cmh.is_integral(lower_bounds)
and cmh.is_integral(upper_bounds)
and lower_bounds > upper_bounds
):
raise ValueError(
@@ -2984,7 +2984,7 @@ def _ConvertToSeriesAndValidateIndex(
TypeError: If the type of `value_or_series` is not recognized.
ValueError: If the index does not match.
"""
if isinstance(value_or_series, (bool, IntegralT)):
if cmh.is_integral(value_or_series) or isinstance(value_or_series, bool):
result = pd.Series(data=value_or_series, index=index)
elif isinstance(value_or_series, pd.Series):
if value_or_series.index.equals(index):