Skip to content

Commit dc8d4ed

Browse files
committed
torch: add take_along_axis
1 parent b3b3a05 commit dc8d4ed

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

array_api_compat/torch/_aliases.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
744744
axis = 0
745745
return torch.index_select(x, axis, indices, **kwargs)
746746

747+
748+
def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
749+
return torch.take_along_dim(x, indices, dim=axis)
750+
751+
747752
def sign(x: array, /) -> array:
748753
# torch sign() does not support complex numbers and does not propagate
749754
# nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -775,6 +780,6 @@ def sign(x: array, /) -> array:
775780
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
776781
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
777782
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
778-
'take', 'sign']
783+
'take', 'take_along_axis', 'sign']
779784

780785
_all_ignore = ['torch', 'get_xp']

0 commit comments

Comments
 (0)