This commit is contained in:
Laurent Perron
2023-03-05 08:18:45 +01:00
parent 1d061ed19e
commit 471a45dd55
4 changed files with 12 additions and 11 deletions

View File

@@ -47,6 +47,7 @@ IntegerT = Union[numbers.Integral, np.integer]
LinearExprT = Union['LinearExpr', NumberT]
ConstraintT = Union['VarCompVar', 'BoundedLinearExpression', bool]
ShapeT = Union[IntegerT, Sequence[IntegerT]]
VariablesT = Union['VariableContainer', 'Variable']
NumpyFuncT = Callable[[
'VariableContainer',
Optional[Union[NumberT, npt.NDArray[np.number], Sequence[NumberT]]],
@@ -578,10 +579,7 @@ class VariableContainer(mixins.NDArrayOperatorsMixin):
def variable_indices(self) -> npt.NDArray[np.int32]:
return self.__variable_indices
def __getitem__(
self,
pos: SliceT,
) -> Union['VariableContainer', Variable]:
def __getitem__(self, pos: SliceT) -> VariablesT:
# delegate the treatment of the 'pos' query to __variable_indices.
index_or_slice: Union[np.int32, npt.NDArray[np.int32]] = (
self.__variable_indices[pos])
@@ -590,11 +588,7 @@ class VariableContainer(mixins.NDArrayOperatorsMixin):
else:
return VariableContainer(self.__helper, index_or_slice)
def index_at(
self,
pos: Union[slice, int, List[int], Tuple[Union[int, slice, List[int]],
...]],
) -> Union[np.int32, npt.NDArray[np.int32]]:
def index_at(self, pos: SliceT) -> Union[np.int32, npt.NDArray[np.int32]]:
"""Returns the index of the variable at the position 'pos'."""
return self.__variable_indices[pos]
@@ -616,6 +610,10 @@ class VariableContainer(mixins.NDArrayOperatorsMixin):
"""Returns the number of variables in the numpy array."""
return self.__variable_indices.size
def ravel(self) -> 'VariableContainer':
"""returns the ravel array of variables."""
return VariableContainer(self.__helper, self.__variable_indices.ravel())
def flatten(self) -> 'VariableContainer':
"""returns the flattened array of variables."""
return VariableContainer(self.__helper,