@@ -761,6 +761,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
761
761
axis = 0
762
762
return torch .index_select (x , axis , indices , ** kwargs )
763
763
764
+
765
+ def take_along_axis (x : array , indices : array , / , * , axis : int = - 1 ) -> array :
766
+ return torch .take_along_dim (x , indices , dim = axis )
767
+
768
+
764
769
def sign (x : array , / ) -> array :
765
770
# torch sign() does not support complex numbers and does not propagate
766
771
# nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -784,14 +789,14 @@ def sign(x: array, /) -> array:
784
789
'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
785
790
'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
786
791
'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
787
- 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'sort' , 'prod' , 'sum' ,
792
+ 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , ' sort' , 'prod' , 'sum' ,
788
793
'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze' ,
789
794
'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
790
795
'arange' , 'eye' , 'linspace' , 'full' , 'ones' , 'zeros' , 'empty' ,
791
796
'tril' , 'triu' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
792
797
'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
793
798
'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
794
799
'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
795
- 'take' , 'sign' ]
800
+ 'take' , 'take_along_axis' , ' sign' ]
796
801
797
802
_all_ignore = ['torch' , 'get_xp' ]
0 commit comments