Skip to content

Commit 1fd3429

Browse files
authored
Add specification for computing the cumulative sum
PR-URL: data-apis#653
1 parent 5a14534 commit 1fd3429

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

spec/draft/API_specification/statistical_functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Objects in API
1818
:toctree: generated
1919
:template: method.rst
2020

21+
cumulative_sum
2122
max
2223
mean
2324
min

src/array_api_stubs/_draft/statistical_functions.py

+59-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,67 @@
1-
__all__ = ["max", "mean", "min", "prod", "std", "sum", "var"]
1+
__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
22

33

44
from ._types import Optional, Tuple, Union, array, dtype
55

66

7+
def cumulative_sum(
8+
x: array,
9+
/,
10+
*,
11+
axis: Optional[int] = None,
12+
dtype: Optional[dtype] = None,
13+
include_initial: bool = False,
14+
) -> array:
15+
"""
16+
Calculates the cumulative sum of elements in the input array ``x``.
17+
18+
Parameters
19+
----------
20+
x: array
21+
input array. Should have a numeric data type.
22+
axis: Optional[int]
23+
axis along which a cumulative sum must be computed. If ``axis`` is negative, the function must determine the axis along which to compute a cumulative sum by counting from the last dimension.
24+
25+
If ``x`` is a one-dimensional array, providing an ``axis`` is optional; however, if ``x`` has more than one dimension, providing an ``axis`` is required.
26+
dtype: Optional[dtype]
27+
data type of the returned array. If ``None``,
28+
29+
- if the default data type corresponding to the data type "kind" (integer, real-valued floating-point, or complex floating-point) of ``x`` has a smaller range of values than the data type of ``x`` (e.g., ``x`` has data type ``int64`` and the default data type is ``int32``, or ``x`` has data type ``uint64`` and the default data type is ``int64``), the returned array must have the same data type as ``x``.
30+
31+
- if the default data type corresponding to the data type "kind" of ``x`` has the same or a larger range of values than the data type of ``x``,
32+
33+
- if ``x`` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
34+
- if ``x`` has a complex floating-point data type, the returned array must have the default complex floating-point data type.
35+
- if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type.
36+
- if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type).
37+
38+
If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the sum. Default: ``None``.
39+
40+
.. note::
41+
keyword argument is intended to help prevent data type overflows.
42+
43+
include_initial: bool
44+
boolean indicating whether to include the initial value as the first value in the output. By convention, the initial value must be the additive identity (i.e., zero). Default: ``False``.
45+
46+
Returns
47+
-------
48+
out: array
49+
an array containing the cumulative sums. The returned array must have a data type as described by the ``dtype`` parameter above.
50+
51+
Let ``N`` be the size of the axis along which to compute the cumulative sum. The returned array must have a shape determined according to the following rules:
52+
53+
- if ``include_initial`` is ``True``, the returned array must have the same shape as ``x``, except the size of the axis along which to compute the cumulative sum must be ``N+1``.
54+
- if ``include_initial`` is ``False``, the returned array must have the same shape as ``x``.
55+
56+
Notes
57+
-----
58+
59+
**Special Cases**
60+
61+
For both real-valued and complex floating-point operands, special cases must be handled as if the operation is implemented by successive application of :func:`~array_api.add`.
62+
"""
63+
64+
765
def max(
866
x: array,
967
/,

0 commit comments

Comments
 (0)