Skip to content

Commit c0dd5b0

Browse files
committed
Add cumulative_sum to torch
1 parent cb9acd4 commit c0dd5b0

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

array_api_compat/common/_aliases.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ def cumulative_sum(
277277
include_initial: bool = False,
278278
**kwargs
279279
) -> ndarray:
280+
wrapped_xp = array_namespace(x)
281+
280282
# TODO: The standard is not clear about what should happen when x.ndim == 0.
281283
if axis is None:
282284
if x.ndim > 1:
@@ -290,7 +292,7 @@ def cumulative_sum(
290292
initial_shape = list(x.shape)
291293
initial_shape[axis] = 1
292294
res = xp.concatenate(
293-
[xp.zeros_like(res, shape=initial_shape), res],
295+
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
294296
axis=axis,
295297
)
296298
return res

array_api_compat/torch/_aliases.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
from builtins import all as _builtin_all, any as _builtin_any
55

66
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7-
vecdot as _aliases_vecdot, clip as
8-
_aliases_clip, unstack as _aliases_unstack,)
7+
vecdot as _aliases_vecdot,
8+
clip as _aliases_clip,
9+
unstack as _aliases_unstack,
10+
cumulative_sum as _aliases_cumulative_sum,
11+
)
912
from .._internal import get_xp
1013

1114
from ._info import __array_namespace_info__
@@ -198,6 +201,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
198201

199202
clip = get_xp(torch)(_aliases_clip)
200203
unstack = get_xp(torch)(_aliases_unstack)
204+
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
201205

202206
# torch.sort also returns a tuple
203207
# https://github.com/pytorch/pytorch/issues/70921
@@ -732,11 +736,11 @@ def sign(x: array, /) -> array:
732736
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'divide',
733737
'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot',
734738
'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
735-
'remainder', 'subtract', 'max', 'min', 'clip', 'unstack', 'sort',
736-
'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat',
737-
'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where',
738-
'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros',
739-
'empty', 'tril', 'triu', 'expand_dims', 'astype',
739+
'remainder', 'subtract', 'max', 'min', 'clip', 'unstack',
740+
'cumulative_sum', 'sort', 'prod', 'sum', 'any', 'all', 'mean',
741+
'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
742+
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
743+
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
740744
'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult',
741745
'UniqueInverseResult', 'unique_all', 'unique_counts',
742746
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',

0 commit comments

Comments
 (0)