Skip to content

Commit 45d302a

Browse files
committed
ENH: add cumulative_prod
1 parent beac55b commit 45d302a

File tree

5 files changed

+35
-1
lines changed

5 files changed

+35
-1
lines changed

array_api_compat/common/_aliases.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,36 @@ def cumulative_sum(
297297
)
298298
return res
299299

300+
301+
def cumulative_prod(
302+
x: ndarray,
303+
/,
304+
xp,
305+
*,
306+
axis: Optional[int] = None,
307+
dtype: Optional[Dtype] = None,
308+
include_initial: bool = False,
309+
**kwargs
310+
) -> ndarray:
311+
wrapped_xp = array_namespace(x)
312+
313+
if axis is None:
314+
if x.ndim > 1:
315+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
316+
axis = 0
317+
318+
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
319+
320+
# np.cumprod does not support include_initial
321+
if include_initial:
322+
initial_shape = list(x.shape)
323+
initial_shape[axis] = 1
324+
res = xp.concatenate(
325+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
326+
axis=axis,
327+
)
328+
return res
329+
300330
# The min and max argument names in clip are different and not optional in numpy, and type
301331
# promotion behavior is different.
302332
def clip(
@@ -549,7 +579,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
549579
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
550580
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
551581
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
552-
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
582+
'astype', 'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
553583
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
554584
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
555585
'unstack', 'sign']

array_api_compat/cupy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
std = get_xp(cp)(_aliases.std)
5151
var = get_xp(cp)(_aliases.var)
5252
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
53+
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
5354
clip = get_xp(cp)(_aliases.clip)
5455
permute_dims = get_xp(cp)(_aliases.permute_dims)
5556
reshape = get_xp(cp)(_aliases.reshape)

array_api_compat/dask/array/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def _dask_arange(
8686
std = get_xp(da)(_aliases.std)
8787
var = get_xp(da)(_aliases.var)
8888
cumulative_sum = get_xp(da)(_aliases.cumulative_sum)
89+
cumulative_prod = get_xp(da)(_aliases.cumulative_prod)
8990
empty = get_xp(da)(_aliases.empty)
9091
empty_like = get_xp(da)(_aliases.empty_like)
9192
full = get_xp(da)(_aliases.full)

array_api_compat/numpy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
std = get_xp(np)(_aliases.std)
5151
var = get_xp(np)(_aliases.var)
5252
cumulative_sum = get_xp(np)(_aliases.cumulative_sum)
53+
cumulative_prod = get_xp(np)(_aliases.cumulative_prod)
5354
clip = get_xp(np)(_aliases.clip)
5455
permute_dims = get_xp(np)(_aliases.permute_dims)
5556
reshape = get_xp(np)(_aliases.reshape)

array_api_compat/torch/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
204204
clip = get_xp(torch)(_aliases_clip)
205205
unstack = get_xp(torch)(_aliases_unstack)
206206
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
207+
cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
207208

208209
# torch.sort also returns a tuple
209210
# https://github.com/pytorch/pytorch/issues/70921

0 commit comments

Comments
 (0)