Skip to content

Commit dd44814

Browse files
committed
Add a wrapper for sign for NumPy-likes
Fixes data-apis#183
1 parent 5affae5 commit dd44814

File tree

4 files changed

+15
-2
lines changed

4 files changed

+15
-2
lines changed

array_api_compat/common/_aliases.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,22 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
530530
raise ValueError("Input array must be at least 1-d.")
531531
return tuple(xp.moveaxis(x, axis, 0))
532532

533+
# numpy 1.26 does not use the standard definition for sign on complex numbers
534+
535+
def sign(x: array, /, xp, **kwargs) -> array:
536+
if isdtype(x.dtype, 'complex floating', xp=xp):
537+
out = (x/xp.abs(x, **kwargs))[...]
538+
# sign(0) = 0 but the above formula would give nan
539+
out[x == 0+0j] = 0+0j
540+
return out[()]
541+
else:
542+
return xp.sign(x, **kwargs)
543+
533544
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
534545
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
535546
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
536547
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
537548
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
538549
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
539550
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
540-
'unstack']
551+
'unstack', 'sign']

array_api_compat/cupy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
matmul = get_xp(cp)(_aliases.matmul)
6363
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6464
tensordot = get_xp(cp)(_aliases.tensordot)
65+
sign = get_xp(cp)(_aliases.sign)
6566

6667
_copy_default = object()
6768

array_api_compat/dask/array/_aliases.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _dask_arange(
104104
trunc = get_xp(np)(_aliases.trunc)
105105
matmul = get_xp(np)(_aliases.matmul)
106106
tensordot = get_xp(np)(_aliases.tensordot)
107-
107+
sign = get_xp(np)(_aliases.sign)
108108

109109
# asarray also adds the copy keyword, which is not present in numpy 1.0.
110110
def asarray(

array_api_compat/numpy/_aliases.py

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
matmul = get_xp(np)(_aliases.matmul)
6363
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6464
tensordot = get_xp(np)(_aliases.tensordot)
65+
sign = get_xp(np)(_aliases.sign)
6566

6667
def _supports_buffer_protocol(obj):
6768
try:

0 commit comments

Comments
 (0)