Skip to content

Commit 5eb7abf

Browse files
committed
Add unstack to all wrapped libraries
1 parent e383441 commit 5eb7abf

File tree

5 files changed

+24
-3
lines changed

5 files changed

+24
-3
lines changed

array_api_compat/common/_aliases.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -485,10 +485,16 @@ def isdtype(
485485
# array_api_strict implementation will be very strict.
486486
return dtype == kind
487487

488+
# unstack is a new function in the 2023.12 array API standard
489+
def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
490+
if x.ndim == 0:
491+
raise ValueError("Input array must be at least 1-d.")
492+
return tuple(xp.moveaxis(x, axis, 0))
493+
488494
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
489495
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
490496
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
491497
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
492498
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape',
493499
'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul',
494-
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
500+
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack']

array_api_compat/cupy/_aliases.py

+6
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,17 @@ def asarray(
112112
vecdot = cp.vecdot
113113
else:
114114
vecdot = get_xp(cp)(_aliases.vecdot)
115+
115116
if hasattr(cp, 'isdtype'):
116117
isdtype = cp.isdtype
117118
else:
118119
isdtype = get_xp(cp)(_aliases.isdtype)
119120

121+
if hasattr(cp, 'unstack'):
122+
unstack = cp.unstack
123+
else:
124+
unstack = get_xp(cp)(_aliases.unstack)
125+
120126
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
121127
'acosh', 'asin', 'asinh', 'atan', 'atan2',
122128
'atanh', 'bitwise_left_shift', 'bitwise_invert',

array_api_compat/dask/array/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import dask.array as da
4343

4444
isdtype = get_xp(np)(_aliases.isdtype)
45+
unstack = get_xp(da)(_aliases.unstack)
4546
astype = _aliases.astype
4647

4748
# Common aliases

array_api_compat/numpy/_aliases.py

+6
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,17 @@ def asarray(
117117
vecdot = np.vecdot
118118
else:
119119
vecdot = get_xp(np)(_aliases.vecdot)
120+
120121
if hasattr(np, 'isdtype'):
121122
isdtype = np.isdtype
122123
else:
123124
isdtype = get_xp(np)(_aliases.isdtype)
124125

126+
if hasattr(np, 'unstack'):
127+
unstack = np.unstack
128+
else:
129+
unstack = get_xp(np)(_aliases.unstack)
130+
125131
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
126132
'acosh', 'asin', 'asinh', 'atan', 'atan2',
127133
'atanh', 'bitwise_left_shift', 'bitwise_invert',

array_api_compat/torch/_aliases.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7-
vecdot as _aliases_vecdot, clip as _aliases_clip)
7+
vecdot as _aliases_vecdot, clip as
8+
_aliases_clip, unstack as _aliases_unstack,)
89
from .._internal import get_xp
910

1011
import torch
@@ -191,6 +192,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
191192
return torch.amin(x, axis, keepdims=keepdims)
192193

193194
clip = get_xp(torch)(_aliases_clip)
195+
unstack = get_xp(torch)(_aliases_unstack)
194196

195197
# torch.sort also returns a tuple
196198
# https://github.com/pytorch/pytorch/issues/70921
@@ -709,7 +711,7 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
709711
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
710712
'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
711713
'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal',
712-
'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'sort',
714+
'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort',
713715
'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat',
714716
'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where',
715717
'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros',

0 commit comments

Comments
 (0)