Skip to content

Commit 0734064

Browse files
committed
Remove floating-point promotion from sum, prod, and trace
Fixes #152
1 parent d57c671 commit 0734064

File tree

5 files changed

+3
-50
lines changed

5 files changed

+3
-50
lines changed

array_api_compat/common/_aliases.py

+3-39
Original file line numberDiff line numberDiff line change
@@ -389,42 +389,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
389389
raise ValueError("nonzero() does not support zero-dimensional arrays")
390390
return xp.nonzero(x, **kwargs)
391391

392-
# sum() and prod() should always upcast when dtype=None
393-
def sum(
394-
x: ndarray,
395-
/,
396-
xp,
397-
*,
398-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
399-
dtype: Optional[Dtype] = None,
400-
keepdims: bool = False,
401-
**kwargs,
402-
) -> ndarray:
403-
# `xp.sum` already upcasts integers, but not floats or complexes
404-
if dtype is None:
405-
if x.dtype == xp.float32:
406-
dtype = xp.float64
407-
elif x.dtype == xp.complex64:
408-
dtype = xp.complex128
409-
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
410-
411-
def prod(
412-
x: ndarray,
413-
/,
414-
xp,
415-
*,
416-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
417-
dtype: Optional[Dtype] = None,
418-
keepdims: bool = False,
419-
**kwargs,
420-
) -> ndarray:
421-
if dtype is None:
422-
if x.dtype == xp.float32:
423-
dtype = xp.float64
424-
elif x.dtype == xp.complex64:
425-
dtype = xp.complex128
426-
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
427-
428392
# ceil, floor, and trunc return integers for integer inputs
429393

430394
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
@@ -525,6 +489,6 @@ def isdtype(
525489
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
526490
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
527491
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
528-
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
529-
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
530-
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
492+
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape',
493+
'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul',
494+
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_linalg.py

-5
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
147147
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
148148

149149
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
150-
if dtype is None:
151-
if x.dtype == xp.float32:
152-
dtype = xp.float64
153-
elif x.dtype == xp.complex64:
154-
dtype = xp.complex128
155150
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
156151

157152
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',

array_api_compat/cupy/_aliases.py

-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@
5353
argsort = get_xp(cp)(_aliases.argsort)
5454
sort = get_xp(cp)(_aliases.sort)
5555
nonzero = get_xp(cp)(_aliases.nonzero)
56-
sum = get_xp(cp)(_aliases.sum)
57-
prod = get_xp(cp)(_aliases.prod)
5856
ceil = get_xp(cp)(_aliases.ceil)
5957
floor = get_xp(cp)(_aliases.floor)
6058
trunc = get_xp(cp)(_aliases.trunc)

array_api_compat/dask/array/_aliases.py

-2
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ def _dask_arange(
102102
vecdot = get_xp(da)(_aliases.vecdot)
103103

104104
nonzero = get_xp(da)(_aliases.nonzero)
105-
sum = get_xp(np)(_aliases.sum)
106-
prod = get_xp(np)(_aliases.prod)
107105
ceil = get_xp(np)(_aliases.ceil)
108106
floor = get_xp(np)(_aliases.floor)
109107
trunc = get_xp(np)(_aliases.trunc)

array_api_compat/numpy/_aliases.py

-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@
5353
argsort = get_xp(np)(_aliases.argsort)
5454
sort = get_xp(np)(_aliases.sort)
5555
nonzero = get_xp(np)(_aliases.nonzero)
56-
sum = get_xp(np)(_aliases.sum)
57-
prod = get_xp(np)(_aliases.prod)
5856
ceil = get_xp(np)(_aliases.ceil)
5957
floor = get_xp(np)(_aliases.floor)
6058
trunc = get_xp(np)(_aliases.trunc)

0 commit comments

Comments
 (0)