55 _real_numeric_dtypes ,
66 _floating_dtypes ,
77 _numeric_dtypes ,
8- _integer_dtypes
98)
10- from . import _dtypes
11- from . import _info
129from ._array_object import Array
1310from ._dtypes import float32 , complex64
1411from ._flags import requires_api_version , get_array_api_strict_flags
15- from ._creation_functions import zeros
12+ from ._creation_functions import zeros , ones
1613from ._manipulation_functions import concat
1714
1815from typing import TYPE_CHECKING
@@ -65,23 +62,8 @@ def cumulative_prod(
6562 if x .ndim == 0 :
6663 raise ValueError ("Only ndim >= 1 arrays are allowed in cumulative_prod" )
6764
68- # TODO: either all this is done by numpy's cumprod (?), or cumulative_sum should follow the same dance.
69- if dtype is None :
70- if x .dtype in _integer_dtypes :
71- default_int = _info .__array_namespace_info__ ().default_dtypes ()["integral" ]
72- if _dtypes ._bit_width (x .dtype ) < _dtypes ._bit_width (default_int ):
73- if x .dtype in _dtypes ._unsigned_integer_dtypes :
74- # find the unsigned integer of the same width as `default_int`
75- dtype = _dtypes ._get_unsigned_from_signed (default_int )
76- else :
77- dtype = default_int
78- else :
79- dtype = x .dtype
80- else :
81- dtype = x .dtype
82- else :
83- if x .dtype != dtype :
84- x = xp .astype (dtype )
65+ if dtype is not None :
66+ dtype = dtype ._np_dtype
8567
8668 if axis is None :
8769 if x .ndim > 1 :
@@ -92,8 +74,8 @@ def cumulative_prod(
9274 if include_initial :
9375 if axis < 0 :
9476 axis += x .ndim
95- x = concat ([ones (x .shape [:axis ] + (1 ,) + x .shape [axis + 1 :], dtype = dtype ), x ], axis = axis )
96- return Array ._new (np .cumprod (x ._array , axis = axis , dtype = dtype . _np_dtype ), device = x .device )
77+ x = concat ([ones (x .shape [:axis ] + (1 ,) + x .shape [axis + 1 :], dtype = x . dtype ), x ], axis = axis )
78+ return Array ._new (np .cumprod (x ._array , axis = axis , dtype = dtype ), device = x .device )
9779
9880
9981def max (
0 commit comments