Skip to content

Commit 1a37a0f

Browse files
committed
torch: add diff
1 parent 0588591 commit 1a37a0f

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

array_api_compat/torch/_aliases.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
clip as _aliases_clip,
99
unstack as _aliases_unstack,
1010
cumulative_sum as _aliases_cumulative_sum,
11+
cumulative_prod as _aliases_cumulative_prod,
1112
)
1213
from .._internal import get_xp
1314

@@ -505,6 +506,20 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
505506
raise ValueError("nonzero() does not support zero-dimensional arrays")
506507
return torch.nonzero(x, as_tuple=True, **kwargs)
507508

509+
510+
# torch uses `dim` instead of `axis`
511+
def diff(
512+
x: array,
513+
/,
514+
*,
515+
axis: int = -1,
516+
n: int = 1,
517+
prepend: Optional[array] = None,
518+
append: Optional[array] = None,
519+
) -> array:
520+
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
521+
522+
508523
# torch uses `dim` instead of `axis`
509524
def count_nonzero(
510525
x: array,
@@ -765,7 +780,7 @@ def sign(x: array, /) -> array:
765780
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
766781
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
767782
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
768-
'divide',
783+
'diff', 'divide',
769784
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
770785
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
771786
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

0 commit comments

Comments
 (0)