Skip to content

Commit 2539057

Browse files
committed
Fix ruff errors
Ensure nan propagation is still handled correctly for CuPy sign().
1 parent dd44814 commit 2539057

File tree

2 files changed

+6
-10
lines changed

2 files changed

+6
-10
lines changed

array_api_compat/common/_aliases.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -532,14 +532,17 @@ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
532532

533533
# numpy 1.26 does not use the standard definition for sign on complex numbers
534534

535-
def sign(x: array, /, xp, **kwargs) -> array:
535+
def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
536536
if isdtype(x.dtype, 'complex floating', xp=xp):
537537
out = (x/xp.abs(x, **kwargs))[...]
538538
# sign(0) = 0 but the above formula would give nan
539539
out[x == 0+0j] = 0+0j
540-
return out[()]
541540
else:
542-
return xp.sign(x, **kwargs)
541+
out = xp.sign(x, **kwargs)
542+
# CuPy sign() does not propagate nans. See
543+
# https://github.com/data-apis/array-api-compat/issues/136
544+
out[xp.isnan(x)] = xp.nan
545+
return out[()]
543546

544547
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
545548
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',

array_api_compat/cupy/_aliases.py

-7
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,6 @@ def asarray(
110110

111111
return cp.array(obj, dtype=dtype, **kwargs)
112112

113-
def sign(x: ndarray, /) -> ndarray:
114-
# CuPy sign() does not propagate nans. See
115-
# https://github.com/data-apis/array-api-compat/issues/136
116-
out = cp.sign(x)
117-
out[cp.isnan(x)] = cp.nan
118-
return out
119-
120113
# These functions are completely new here. If the library already has them
121114
# (i.e., numpy 2.0), use the library version instead of our wrapper.
122115
if hasattr(cp, 'vecdot'):

0 commit comments

Comments
 (0)