Skip to content

Commit 3ff4ca6

Browse files
committed
MAINT: simplify cumulative_prod
1 parent 912362c commit 3ff4ca6

File tree

2 files changed

+5
-51
lines changed

2 files changed

+5
-51
lines changed

array_api_strict/_dtypes.py

-28
Original file line numberDiff line numberDiff line change
@@ -127,34 +127,6 @@ def __hash__(self):
127127
}
128128

129129

130-
def _bit_width(dtype):
131-
"""The bit width of an integer dtype"""
132-
if dtype == int8 or dtype == uint8:
133-
return 8
134-
elif dtype == int16 or dtype == uint16:
135-
return 16
136-
elif dtype == int32 or dtype == uint32:
137-
return 32
138-
elif dtype == int64 or dtype == uint64:
139-
return 64
140-
else:
141-
raise ValueError(f"_bit_width: {dtype = } not understood.")
142-
143-
144-
def _get_unsigned_from_signed(dtype):
145-
"""Return an unsigned integral dtype to match the input dtype."""
146-
if dtype == int8:
147-
return uint8
148-
elif dtype == int16:
149-
return uint16
150-
elif dtype == int32:
151-
return uint32
152-
elif dtype == int64:
153-
return uint64
154-
else:
155-
raise ValueError(f"_unsigned_from_signed: {dtype = } not understood.")
156-
157-
158130
# Note: the spec defines a restricted type promotion table compared to NumPy.
159131
# In particular, cross-kind promotions like integer + float or boolean +
160132
# integer are not allowed, even for functions that accept both kinds.

array_api_strict/_statistical_functions.py

+5-23
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
_real_numeric_dtypes,
66
_floating_dtypes,
77
_numeric_dtypes,
8-
_integer_dtypes
98
)
10-
from . import _dtypes
11-
from . import _info
129
from ._array_object import Array
1310
from ._dtypes import float32, complex64
1411
from ._flags import requires_api_version, get_array_api_strict_flags
15-
from ._creation_functions import zeros
12+
from ._creation_functions import zeros, ones
1613
from ._manipulation_functions import concat
1714

1815
from typing import TYPE_CHECKING
@@ -65,23 +62,8 @@ def cumulative_prod(
6562
if x.ndim == 0:
6663
raise ValueError("Only ndim >= 1 arrays are allowed in cumulative_prod")
6764

68-
# TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance.
69-
if dtype is None:
70-
if x.dtype in _integer_dtypes:
71-
default_int = _info.__array_namespace_info__().default_dtypes()["integral"]
72-
if _dtypes._bit_width(x.dtype) < _dtypes._bit_width(default_int):
73-
if x.dtype in _dtypes._unsigned_integer_dtypes:
74-
# find the unsigned integer of the same width as `default_int`
75-
dtype = _dtypes._get_unsigned_from_signed(default_int)
76-
else:
77-
dtype = default_int
78-
else:
79-
dtype = x.dtype
80-
else:
81-
dtype = x.dtype
82-
else:
83-
if x.dtype != dtype:
84-
x = xp.astype(dtype)
65+
if dtype is not None:
66+
dtype = dtype._np_dtype
8567

8668
if axis is None:
8769
if x.ndim > 1:
@@ -92,8 +74,8 @@ def cumulative_prod(
9274
if include_initial:
9375
if axis < 0:
9476
axis += x.ndim
95-
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dtype), x], axis=axis)
96-
return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype._np_dtype), device=x.device)
77+
x = concat([ones(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
78+
return Array._new(np.cumprod(x._array, axis=axis, dtype=dtype), device=x.device)
9779

9880

9981
def max(

0 commit comments

Comments
 (0)