Skip to content

Commit 4d81a0d

Browse files
committed
torch: add take_along_axis
1 parent 1a37a0f commit 4d81a0d

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

array_api_compat/torch/_aliases.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
761761
axis = 0
762762
return torch.index_select(x, axis, indices, **kwargs)
763763

764+
765+
def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
766+
return torch.take_along_dim(x, indices, dim=axis)
767+
768+
764769
def sign(x: array, /) -> array:
765770
# torch sign() does not support complex numbers and does not propagate
766771
# nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -784,14 +789,14 @@ def sign(x: array, /) -> array:
784789
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
785790
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
786791
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',
787-
'min', 'clip', 'unstack', 'cumulative_sum', 'sort', 'prod', 'sum',
792+
'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum',
788793
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
789794
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
790795
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
791796
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
792797
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
793798
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
794799
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
795-
'take', 'sign']
800+
'take', 'take_along_axis', 'sign']
796801

797802
_all_ignore = ['torch', 'get_xp']

0 commit comments

Comments
 (0)