|
4 | 4 | from builtins import all as _builtin_all, any as _builtin_any
|
5 | 5 |
|
6 | 6 | from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
|
7 |
| - vecdot as _aliases_vecdot, clip as _aliases_clip) |
| 7 | + vecdot as _aliases_vecdot, clip as |
| 8 | + _aliases_clip, unstack as _aliases_unstack,) |
8 | 9 | from .._internal import get_xp
|
9 | 10 |
|
10 | 11 | import torch
|
@@ -191,6 +192,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
|
191 | 192 | return torch.amin(x, axis, keepdims=keepdims)
|
192 | 193 |
|
193 | 194 | clip = get_xp(torch)(_aliases_clip)
|
| 195 | +unstack = get_xp(torch)(_aliases_unstack) |
194 | 196 |
|
195 | 197 | # torch.sort also returns a tuple
|
196 | 198 | # https://github.com/pytorch/pytorch/issues/70921
|
@@ -709,7 +711,7 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
|
709 | 711 | 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
|
710 | 712 | 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
|
711 | 713 | 'hypot', 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal',
|
712 |
| - 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', |
| 714 | + 'pow', 'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort', |
713 | 715 | 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat',
|
714 | 716 | 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where',
|
715 | 717 | 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros',
|
|
0 commit comments