Skip to content

Commit 912362c

Browse files
committed
ENH: add cumulative_prod (untested)
1 parent f00a882 commit 912362c

File tree

3 files changed

+79
-2
lines changed

3 files changed

+79
-2
lines changed

array_api_strict/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -305,9 +305,9 @@
305305

306306
__all__ += ["argsort", "sort"]
307307

308-
from ._statistical_functions import cumulative_sum, max, mean, min, prod, std, sum, var
308+
from ._statistical_functions import cumulative_sum, cumulative_prod, max, mean, min, prod, std, sum, var
309309

310-
__all__ += ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"]
310+
__all__ += ["cumulative_sum", "cumulative_prod", "max", "mean", "min", "prod", "std", "sum", "var"]
311311

312312
from ._utility_functions import all, any, diff
313313

array_api_strict/_dtypes.py

+28
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,34 @@ def __hash__(self):
127127
}
128128

129129

130+
def _bit_width(dtype):
131+
"""The bit width of an integer dtype"""
132+
if dtype == int8 or dtype == uint8:
133+
return 8
134+
elif dtype == int16 or dtype == uint16:
135+
return 16
136+
elif dtype == int32 or dtype == uint32:
137+
return 32
138+
elif dtype == int64 or dtype == uint64:
139+
return 64
140+
else:
141+
raise ValueError(f"_bit_width: {dtype = } not understood.")
142+
143+
144+
def _get_unsigned_from_signed(dtype):
145+
"""Return an unsigned integral dtype to match the input dtype."""
146+
if dtype == int8:
147+
return uint8
148+
elif dtype == int16:
149+
return uint16
150+
elif dtype == int32:
151+
return uint32
152+
elif dtype == int64:
153+
return uint64
154+
else:
155+
raise ValueError(f"_unsigned_from_signed: {dtype = } not understood.")
156+
157+
130158
# Note: the spec defines a restricted type promotion table compared to NumPy.
131159
# In particular, cross-kind promotions like integer + float or boolean +
132160
# integer are not allowed, even for functions that accept both kinds.

array_api_strict/_statistical_functions.py

+49
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
_real_numeric_dtypes,
66
_floating_dtypes,
77
_numeric_dtypes,
8+
_integer_dtypes
89
)
10+
from . import _dtypes
11+
from . import _info
912
from ._array_object import Array
1013
from ._dtypes import float32, complex64
1114
from ._flags import requires_api_version, get_array_api_strict_flags
@@ -47,6 +50,52 @@ def cumulative_sum(
4750
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
4851
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
4952

53+
54+
@requires_api_version('2024.12')
55+
def cumulative_prod(
56+
x: Array,
57+
/,
58+
*,
59+
axis: Optional[int] = None,
60+
dtype: Optional[Dtype] = None,
61+
include_initial: bool = False,
62+
) -> Array:
63+
if x.dtype not in _numeric_dtypes:
64+
raise TypeError("Only numeric dtypes are allowed in cumulative_prod")
65+
if x.ndim == 0:
66+
raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod")
67+
68+
# TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance.
69+
if dtype is None:
70+
if x.dtype in _integer_dtypes:
71+
default_int = _info.__array_namespace_info__().default_dtypes()["integral"]
72+
if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int):
73+
if x.dtype in _dtypes._unsigned_integer_dtypes:
74+
# find the unsigned integer of the same width as `default_int`
75+
dtype = _dtypes._get_unsigned_from_signed(default_int)
76+
else:
77+
dtype = default_int
78+
else:
79+
dtype = x.dtype
80+
else:
81+
dtype = x.dtype
82+
else:
83+
if x.dtype != dtype:
84+
x = xp.astype(dtype)
85+
86+
if axis is None:
87+
if x.ndim > 1:
88+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
89+
axis = 0
90+
91+
# np.cumprod does not support include_initial
92+
if include_initial:
93+
if axis < 0:
94+
axis += x.ndim
95+
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis)
96+
return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device)
97+
98+
5099
def max(
51100
x: Array,
52101
/,

0 commit comments

Comments
 (0)