Skip to content

Commit b3b3a05

Browse files
committed
torch: add diff
1 parent 55a72f7 commit b3b3a05

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

array_api_compat/torch/_aliases.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
clip as _aliases_clip,
99
unstack as _aliases_unstack,
1010
cumulative_sum as _aliases_cumulative_sum,
11+
cumulative_prod as _aliases_cumulative_prod,
1112
)
1213
from .._internal import get_xp
1314

@@ -499,6 +500,20 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
499500
raise ValueError("nonzero() does not support zero-dimensional arrays")
500501
return torch.nonzero(x, as_tuple=True, **kwargs)
501502

503+
504+
# torch uses `dim` instead of `axis`
505+
def diff(
506+
x: array,
507+
/,
508+
*,
509+
axis: int = -1,
510+
n: int = 1,
511+
prepend: Optional[array] = None,
512+
append: Optional[array] = None,
513+
) -> array:
514+
return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append)
515+
516+
502517
# torch uses `dim` instead of `axis`
503518
def count_nonzero(
504519
x: array,
@@ -748,7 +763,7 @@ def sign(x: array, /) -> array:
748763
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
749764
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
750765
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',
751-
'divide',
766+
'diff', 'divide',
752767
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
753768
'less', 'less_equal', 'logaddexp', 'maximum', 'minimum',
754769
'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max',

0 commit comments

Comments
 (0)