|
5 | 5 | _real_numeric_dtypes,
|
6 | 6 | _floating_dtypes,
|
7 | 7 | _numeric_dtypes,
|
| 8 | + _integer_dtypes |
8 | 9 | )
|
| 10 | +from . import _dtypes |
| 11 | +from . import _info |
9 | 12 | from ._array_object import Array
|
10 | 13 | from ._dtypes import float32, complex64
|
11 | 14 | from ._flags import requires_api_version, get_array_api_strict_flags
|
@@ -47,6 +50,52 @@ def cumulative_sum(
|
47 | 50 | x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
|
48 | 51 | return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
|
49 | 52 |
|
| 53 | + |
| 54 | +@requires_api_version('2024.12') |
| 55 | +def cumulative_prod( |
| 56 | + x: Array, |
| 57 | + /, |
| 58 | + *, |
| 59 | + axis: Optional[int] = None, |
| 60 | + dtype: Optional[Dtype] = None, |
| 61 | + include_initial: bool = False, |
| 62 | +) -> Array: |
| 63 | + if x.dtype not in _numeric_dtypes: |
| 64 | + raise TypeError("Only numeric dtypes are allowed in cumulative_prod") |
| 65 | + if x.ndim == 0: |
| 66 | + raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod") |
| 67 | + |
| 68 | + # TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance. |
| 69 | + if dtype is None: |
| 70 | + if x.dtype in _integer_dtypes: |
| 71 | + default_int = _info.__array_namespace_info__().default_dtypes()["integral"] |
| 72 | + if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int): |
| 73 | + if x.dtype in _dtypes._unsigned_integer_dtypes: |
| 74 | + # find the unsigned integer of the same width as `default_int` |
| 75 | + dtype = _dtypes._get_unsigned_from_signed(default_int) |
| 76 | + else: |
| 77 | + dtype = default_int |
| 78 | + else: |
| 79 | + dtype = x.dtype |
| 80 | + else: |
| 81 | + dtype = x.dtype |
| 82 | + else: |
| 83 | + if x.dtype != dtype: |
| 84 | + x = xp.astype(dtype) |
| 85 | + |
| 86 | + if axis is None: |
| 87 | + if x.ndim > 1: |
| 88 | + raise ValueError("axis must be specified in cumulative_prod for more than one dimension") |
| 89 | + axis = 0 |
| 90 | + |
| 91 | + # np.cumprod does not support include_initial |
| 92 | + if include_initial: |
| 93 | + if axis < 0: |
| 94 | + axis += x.ndim |
| 95 | + x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis) |
| 96 | + return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device) |
| 97 | + |
| 98 | + |
50 | 99 | def max(
|
51 | 100 | x: Array,
|
52 | 101 | /,
|
|
0 commit comments